From ec5c45fe8f27b95c699c81205f3d5accaee9ab2b Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 19 Apr 2025 21:19:22 +0000 Subject: [PATCH 1/2] feat: move to sqlglot --- sqlspec/adapters/adbc/driver.py | 128 +++---------- sqlspec/adapters/aiosqlite/driver.py | 66 +++---- sqlspec/adapters/asyncmy/driver.py | 58 +++--- sqlspec/adapters/asyncpg/driver.py | 107 +++++------ sqlspec/adapters/duckdb/driver.py | 78 ++++---- sqlspec/adapters/oracledb/driver.py | 118 +++++++----- sqlspec/adapters/psycopg/driver.py | 271 ++++++--------------------- sqlspec/adapters/sqlite/driver.py | 67 +++---- sqlspec/base.py | 244 +++++++++++++++++++----- 9 files changed, 515 insertions(+), 622 deletions(-) diff --git a/sqlspec/adapters/adbc/driver.py b/sqlspec/adapters/adbc/driver.py index 309b93baf..9123926f7 100644 --- a/sqlspec/adapters/adbc/driver.py +++ b/sqlspec/adapters/adbc/driver.py @@ -53,80 +53,15 @@ def _with_cursor(self, connection: "Connection") -> Generator["DbapiCursor", Non with contextlib.suppress(Exception): cursor.close() # type: ignore[no-untyped-call] - def _process_sql_params( - self, sql: str, parameters: "Optional[StatementParameterType]" = None - ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]": - """Process SQL query and parameters for DB-API execution. - - Converts named parameters (:name or %(name)s) to positional parameters specified by `self.param_style` - if the input parameters are a dictionary. - - Args: - sql: The SQL query string. - parameters: The parameters for the query (dict, tuple, list, or None). - - Returns: - A tuple containing the processed SQL string and the processed parameters - (always a tuple or None if the input was a dictionary, otherwise the original type). - - Raises: - ValueError: If a named parameter in the SQL is not found in the dictionary - or if a parameter in the dictionary is not used in the SQL. - """ - if not isinstance(parameters, dict) or not parameters: - # If parameters are not a dict, or empty dict, assume positional/no params - # Let the underlying driver handle tuples/lists directly - return self._process_sql_statement(sql), parameters - - processed_sql = "" - processed_params_list: list[Any] = [] - last_end = 0 - found_params: set[str] = set() - - for match in PARAM_REGEX.finditer(sql): - if match.group("dquote") is not None or match.group("squote") is not None: - # Skip placeholders within quotes - continue - - # Get name from whichever group matched - var_name = match.group("var_name_colon") or match.group("var_name_perc") - - if var_name is None: # Should not happen with the new regex structure - continue - - if var_name not in parameters: - placeholder = match.group(0) # Get the full matched placeholder - msg = f"Named parameter '{placeholder}' found in SQL but not provided in parameters dictionary." - raise ValueError(msg) - - # Append segment before the placeholder - processed_sql += sql[last_end : match.start()] - # Append the driver's positional placeholder - processed_sql += self.param_style - processed_params_list.append(parameters[var_name]) - found_params.add(var_name) - last_end = match.end() - - # Append the rest of the SQL string - processed_sql += sql[last_end:] - - # Check if all provided parameters were used - unused_params = set(parameters.keys()) - found_params - if unused_params: - msg = f"Parameters provided but not found in SQL: {unused_params}" - # Depending on desired strictness, this could be a warning or an error - # For now, let's raise an error for clarity - raise ValueError(msg) - - return self._process_sql_statement(processed_sql), tuple(processed_params_list) - def select( self, sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "list[Union[ModelDTOT, dict[str, Any]]]": """Fetch data from the database. @@ -134,7 +69,7 @@ def select( List of row data as either model instances or dictionaries. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] results = cursor.fetchall() # pyright: ignore @@ -152,8 +87,10 @@ def select_one( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Union[ModelDTOT, dict[str, Any]]": """Fetch one row from the database. @@ -161,7 +98,7 @@ def select_one( The first row of the query results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] result = cursor.fetchone() # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] @@ -176,8 +113,10 @@ def select_one_or_none( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]": """Fetch one row from the database. @@ -185,7 +124,7 @@ def select_one_or_none( The first row of the query results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] result = cursor.fetchone() # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] @@ -201,8 +140,10 @@ def select_value( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, schema_type: "Optional[type[T]]" = None, + **kwargs: Any, ) -> "Union[T, Any]": """Fetch a single value from the database. @@ -210,7 +151,7 @@ def select_value( The first value from the first row of results, or None if no results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] result = cursor.fetchone() # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] @@ -224,8 +165,10 @@ def select_value_or_none( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, schema_type: "Optional[type[T]]" = None, + **kwargs: Any, ) -> "Optional[Union[T, Any]]": """Fetch a single value from the database. @@ -233,7 +176,7 @@ def select_value_or_none( The first value from the first row of results, or None if no results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] result = cursor.fetchone() # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] @@ -248,7 +191,9 @@ def insert_update_delete( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, + **kwargs: Any, ) -> int: """Insert, update, or delete data from the database. @@ -256,7 +201,7 @@ def insert_update_delete( Row count affected by the operation. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] @@ -267,8 +212,10 @@ def insert_update_delete_returning( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": """Insert, update, or delete data from the database and return result. @@ -276,7 +223,7 @@ def insert_update_delete_returning( The first row of results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) column_names: list[str] = [] with self._with_cursor(connection) as cursor: @@ -294,7 +241,9 @@ def execute_script( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, + **kwargs: Any, ) -> str: """Execute a script. @@ -308,33 +257,6 @@ def execute_script( cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] return cast("str", cursor.statusmessage) if hasattr(cursor, "statusmessage") else "DONE" # pyright: ignore[reportUnknownMemberType,reportAttributeAccessIssue] - def execute_script_returning( - self, - sql: str, - parameters: Optional["StatementParameterType"] = None, - /, - connection: Optional["Connection"] = None, - schema_type: "Optional[type[ModelDTOT]]" = None, - ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": - """Execute a script and return result. - - Returns: - The first row of results. - """ - connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - column_names: list[str] = [] - - with self._with_cursor(connection) as cursor: - cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] - result = cursor.fetchall() # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] - if len(result) == 0: # pyright: ignore[reportUnknownArgumentType] - return None - column_names = [c[0] for c in cursor.description or []] # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] - if schema_type is not None: - return cast("ModelDTOT", schema_type(**dict(zip(column_names, result[0])))) # pyright: ignore[reportUnknownArgumentType] - return dict(zip(column_names, result[0])) # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType] - # --- Arrow Bulk Operations --- def select_arrow( # pyright: ignore[reportUnknownParameterType] @@ -342,7 +264,9 @@ def select_arrow( # pyright: ignore[reportUnknownParameterType] sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, + **kwargs: Any, ) -> "ArrowTable": """Execute a SQL query and return results as an Apache Arrow Table. @@ -350,7 +274,7 @@ def select_arrow( # pyright: ignore[reportUnknownParameterType] The results of the query as an Apache Arrow Table. """ conn = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(conn) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] diff --git a/sqlspec/adapters/aiosqlite/driver.py b/sqlspec/adapters/aiosqlite/driver.py index 06e4a5f61..dbfea15fa 100644 --- a/sqlspec/adapters/aiosqlite/driver.py +++ b/sqlspec/adapters/aiosqlite/driver.py @@ -33,42 +33,15 @@ async def _with_cursor(self, connection: "Connection") -> "AsyncGenerator[Cursor finally: await cursor.close() - def _process_sql_params( - self, sql: str, parameters: "Optional[StatementParameterType]" = None - ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]": - """Process SQL query and parameters for DB-API execution. - - Converts named parameters (:name) to positional parameters (?) for SQLite. - - Args: - sql: The SQL query string. - parameters: The parameters for the query (dict, tuple, list, or None). - - Returns: - A tuple containing the processed SQL string and the processed parameters. - """ - if not isinstance(parameters, dict) or not parameters: - # If parameters are not a dict, or empty dict, assume positional/no params - # Let the underlying driver handle tuples/lists directly - return sql, parameters - - # Convert named parameters to positional parameters - processed_sql = sql - processed_params: list[Any] = [] - for key, value in parameters.items(): - # Replace :key with ? in the SQL - processed_sql = processed_sql.replace(f":{key}", "?") - processed_params.append(value) - - return processed_sql, tuple(processed_params) - async def select( self, sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "list[Union[ModelDTOT, dict[str, Any]]]": """Fetch data from the database. @@ -76,7 +49,7 @@ async def select( List of row data as either model instances or dictionaries. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] results = await cursor.fetchall() # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] @@ -92,8 +65,10 @@ async def select_one( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Union[ModelDTOT, dict[str, Any]]": """Fetch one row from the database. @@ -101,7 +76,7 @@ async def select_one( The first row of the query results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] result = await cursor.fetchone() # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] @@ -116,8 +91,10 @@ async def select_one_or_none( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]": """Fetch one row from the database. @@ -125,7 +102,7 @@ async def select_one_or_none( The first row of the query results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] result = await cursor.fetchone() # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] @@ -141,8 +118,10 @@ async def select_value( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, schema_type: "Optional[type[T]]" = None, + **kwargs: Any, ) -> "Union[T, Any]": """Fetch a single value from the database. @@ -150,7 +129,7 @@ async def select_value( The first value from the first row of results, or None if no results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] result = await cursor.fetchone() # pyright: ignore[reportUnknownMemberType] @@ -164,8 +143,10 @@ async def select_value_or_none( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, schema_type: "Optional[type[T]]" = None, + **kwargs: Any, ) -> "Optional[Union[T, Any]]": """Fetch a single value from the database. @@ -173,8 +154,7 @@ async def select_value_or_none( The first value from the first row of results, or None if no results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] result = await cursor.fetchone() # pyright: ignore[reportUnknownMemberType] @@ -189,7 +169,9 @@ async def insert_update_delete( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, + **kwargs: Any, ) -> int: """Insert, update, or delete data from the database. @@ -197,7 +179,7 @@ async def insert_update_delete( Row count affected by the operation. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] @@ -208,8 +190,10 @@ async def insert_update_delete_returning( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": """Insert, update, or delete data from the database and return result. @@ -217,7 +201,7 @@ async def insert_update_delete_returning( The first row of results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] @@ -234,7 +218,9 @@ async def execute_script( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, + **kwargs: Any, ) -> str: """Execute a script. @@ -242,7 +228,7 @@ async def execute_script( Status message for the operation. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] @@ -253,8 +239,10 @@ async def execute_script_returning( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": """Execute a script and return result. @@ -262,7 +250,7 @@ async def execute_script_returning( The first row of results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] diff --git a/sqlspec/adapters/asyncmy/driver.py b/sqlspec/adapters/asyncmy/driver.py index ff6561006..2d9e06f98 100644 --- a/sqlspec/adapters/asyncmy/driver.py +++ b/sqlspec/adapters/asyncmy/driver.py @@ -40,8 +40,10 @@ async def select( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "list[Union[ModelDTOT, dict[str, Any]]]": """Fetch data from the database. @@ -49,7 +51,7 @@ async def select( List of row data as either model instances or dictionaries. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) results = await cursor.fetchall() @@ -65,8 +67,10 @@ async def select_one( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Union[ModelDTOT, dict[str, Any]]": """Fetch one row from the database. @@ -74,7 +78,7 @@ async def select_one( The first row of the query results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) result = await cursor.fetchone() @@ -89,8 +93,10 @@ async def select_one_or_none( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]": """Fetch one row from the database. @@ -98,7 +104,7 @@ async def select_one_or_none( The first row of the query results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) result = await cursor.fetchone() @@ -114,8 +120,10 @@ async def select_value( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, schema_type: "Optional[type[T]]" = None, + **kwargs: Any, ) -> "Union[T, Any]": """Fetch a single value from the database. @@ -123,7 +131,7 @@ async def select_value( The first value from the first row of results, or None if no results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) @@ -140,8 +148,10 @@ async def select_value_or_none( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, schema_type: "Optional[type[T]]" = None, + **kwargs: Any, ) -> "Optional[Union[T, Any]]": """Fetch a single value from the database. @@ -149,7 +159,7 @@ async def select_value_or_none( The first value from the first row of results, or None if no results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) @@ -168,7 +178,9 @@ async def insert_update_delete( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, + **kwargs: Any, ) -> int: """Insert, update, or delete data from the database. @@ -176,7 +188,7 @@ async def insert_update_delete( Row count affected by the operation. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) @@ -187,8 +199,10 @@ async def insert_update_delete_returning( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": """Insert, update, or delete data from the database and return result. @@ -196,7 +210,7 @@ async def insert_update_delete_returning( The first row of results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) column_names: list[str] = [] async with self._with_cursor(connection) as cursor: @@ -214,7 +228,9 @@ async def execute_script( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, + **kwargs: Any, ) -> str: """Execute a script. @@ -222,34 +238,8 @@ async def execute_script( Status message for the operation. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) return "DONE" - - async def execute_script_returning( - self, - sql: str, - parameters: Optional["StatementParameterType"] = None, - /, - connection: Optional["Connection"] = None, - schema_type: "Optional[type[ModelDTOT]]" = None, - ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": - """Execute a script and return result. - - Returns: - The first row of results. - """ - connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - - async with self._with_cursor(connection) as cursor: - await cursor.execute(sql, parameters) - result = await cursor.fetchone() - if result is None: - return None - column_names = [c[0] for c in cursor.description or []] - if schema_type is not None: - return cast("ModelDTOT", schema_type(**dict(zip(column_names, result)))) - return dict(zip(column_names, result)) diff --git a/sqlspec/adapters/asyncpg/driver.py b/sqlspec/adapters/asyncpg/driver.py index 48ad1f67b..edf3fc9ff 100644 --- a/sqlspec/adapters/asyncpg/driver.py +++ b/sqlspec/adapters/asyncpg/driver.py @@ -25,19 +25,15 @@ class AsyncpgDriver(AsyncDriverAdapterProtocol["AsyncpgConnection"]): def __init__(self, connection: "AsyncpgConnection") -> None: self.connection = connection - def _process_sql_params( - self, sql: str, parameters: "Optional[StatementParameterType]" = None - ) -> "tuple[str, Union[tuple[Any, ...], list[Any], dict[str, Any]]]": - sql, parameters = super()._process_sql_params(sql, parameters) - return sql, parameters if parameters is not None else () - async def select( self, sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["AsyncpgConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "list[Union[ModelDTOT, dict[str, Any]]]": """Fetch data from the database. @@ -46,12 +42,14 @@ async def select( parameters: Query parameters. connection: Optional connection to use. schema_type: Optional schema class for the result. + **kwargs: Additional keyword arguments. Returns: List of row data as either model instances or dictionaries. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) + parameters = parameters if parameters is not None else () results = await connection.fetch(sql, *parameters) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] if not results: @@ -65,8 +63,10 @@ async def select_one( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["AsyncpgConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Union[ModelDTOT, dict[str, Any]]": """Fetch one row from the database. @@ -75,13 +75,14 @@ async def select_one( parameters: Query parameters. connection: Optional connection to use. schema_type: Optional schema class for the result. + **kwargs: Additional keyword arguments. Returns: The first row of the query results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) + parameters = parameters if parameters is not None else () result = await connection.fetchrow(sql, *parameters) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] result = self.check_not_found(result) @@ -95,8 +96,10 @@ async def select_one_or_none( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["AsyncpgConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]": """Fetch one row from the database. @@ -105,36 +108,47 @@ async def select_one_or_none( parameters: Query parameters. connection: Optional connection to use. schema_type: Optional schema class for the result. + **kwargs: Additional keyword arguments. Returns: The first row of the query results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) + parameters = parameters if parameters is not None else () result = await connection.fetchrow(sql, *parameters) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] - result = self.check_not_found(result) + if result is None: + return None if schema_type is None: # Always return as dictionary - return dict(result.items()) # type: ignore[attr-defined] - return cast("ModelDTOT", schema_type(**dict(result.items()))) # type: ignore[attr-defined] + return dict(result.items()) + return cast("ModelDTOT", schema_type(**dict(result.items()))) async def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[AsyncpgConnection]" = None, schema_type: "Optional[type[T]]" = None, + **kwargs: Any, ) -> "Union[T, Any]": """Fetch a single value from the database. + Args: + sql: SQL statement. + parameters: Query parameters. + connection: Optional connection to use. + schema_type: Optional schema class for the result. + **kwargs: Additional keyword arguments. + Returns: The first value from the first row of results, or None if no results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) + parameters = parameters if parameters is not None else () result = await connection.fetchval(sql, *parameters) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] result = self.check_not_found(result) if schema_type is None: @@ -146,8 +160,10 @@ async def select_value_or_none( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[AsyncpgConnection]" = None, schema_type: "Optional[type[T]]" = None, + **kwargs: Any, ) -> "Optional[Union[T, Any]]": """Fetch a single value from the database. @@ -155,8 +171,8 @@ async def select_value_or_none( The first value from the first row of results, or None if no results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) + parameters = parameters if parameters is not None else () result = await connection.fetchval(sql, *parameters) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] if result is None: return None @@ -169,7 +185,9 @@ async def insert_update_delete( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["AsyncpgConnection"] = None, + **kwargs: Any, ) -> int: """Insert, update, or delete data from the database. @@ -177,13 +195,14 @@ async def insert_update_delete( sql: SQL statement. parameters: Query parameters. connection: Optional connection to use. + **kwargs: Additional keyword arguments. Returns: Row count affected by the operation. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) + parameters = parameters if parameters is not None else () status = await connection.execute(sql, *parameters) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] # AsyncPG returns a string like "INSERT 0 1" where the last number is the affected rows try: @@ -196,23 +215,26 @@ async def insert_update_delete_returning( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["AsyncpgConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": - """Insert, update, or delete data from the database and return result. + """Insert, update, or delete data from the database and return the affected row. Args: sql: SQL statement. parameters: Query parameters. connection: Optional connection to use. schema_type: Optional schema class for the result. + **kwargs: Additional keyword arguments. Returns: - The first row of results. + The affected row data as either a model instance or dictionary. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) + parameters = parameters if parameters is not None else () result = await connection.fetchrow(sql, *parameters) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] if result is None: return None @@ -226,7 +248,9 @@ async def execute_script( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["AsyncpgConnection"] = None, + **kwargs: Any, ) -> str: """Execute a script. @@ -234,41 +258,12 @@ async def execute_script( sql: SQL statement. parameters: Query parameters. connection: Optional connection to use. + **kwargs: Additional keyword arguments. Returns: Status message for the operation. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) + parameters = parameters if parameters is not None else () return await connection.execute(sql, *parameters) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] - - async def execute_script_returning( - self, - sql: str, - parameters: Optional["StatementParameterType"] = None, - /, - connection: Optional["AsyncpgConnection"] = None, - schema_type: "Optional[type[ModelDTOT]]" = None, - ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": - """Execute a script and return result. - - Args: - sql: SQL statement. - parameters: Query parameters. - connection: Optional connection to use. - schema_type: Optional schema class for the result. - - Returns: - The first row of results. - """ - connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - - result = await connection.fetchrow(sql, *parameters) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] - if result is None: - return None - if schema_type is None: - # Always return as dictionary - return dict(result.items()) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] - return cast("ModelDTOT", schema_type(**dict(result.items()))) # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType, reportUnknownVariableType] diff --git a/sqlspec/adapters/duckdb/driver.py b/sqlspec/adapters/duckdb/driver.py index 6b01453c9..80677c574 100644 --- a/sqlspec/adapters/duckdb/driver.py +++ b/sqlspec/adapters/duckdb/driver.py @@ -49,11 +49,25 @@ def select( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["DuckDBPyConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "list[Union[ModelDTOT, dict[str, Any]]]": + """Fetch data from the database. + + Args: + sql: SQL statement. + parameters: Query parameters. + connection: Optional connection to use. + schema_type: Optional schema class for the result. + **kwargs: Additional keyword arguments. + + Returns: + List of row data as either model instances or dictionaries. + """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] results = cursor.fetchall() # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] @@ -71,8 +85,10 @@ def select_one( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["DuckDBPyConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Union[ModelDTOT, dict[str, Any]]": connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters) @@ -93,11 +109,13 @@ def select_one_or_none( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["DuckDBPyConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]": connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] @@ -115,8 +133,10 @@ def select_value( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[DuckDBPyConnection]" = None, schema_type: "Optional[type[T]]" = None, + **kwargs: Any, ) -> "Union[T, Any]": connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters) @@ -134,8 +154,10 @@ def select_value_or_none( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[DuckDBPyConnection]" = None, schema_type: "Optional[type[T]]" = None, + **kwargs: Any, ) -> "Optional[Union[T, Any]]": connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters) @@ -153,10 +175,12 @@ def insert_update_delete( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["DuckDBPyConnection"] = None, + **kwargs: Any, ) -> int: connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] return getattr(cursor, "rowcount", -1) # pyright: ignore[reportUnknownMemberType] @@ -166,11 +190,13 @@ def insert_update_delete_returning( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["DuckDBPyConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] result = cursor.fetchall() # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] @@ -182,44 +208,17 @@ def insert_update_delete_returning( # Always return dictionaries return dict(zip(column_names, result[0])) # pyright: ignore[reportUnknownArgumentType,reportUnknownVariableType] - def _process_sql_params( - self, sql: str, parameters: "Optional[StatementParameterType]" = None - ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]": - """Process SQL query and parameters for DB-API execution. - - Converts named parameters (:name) to positional parameters (?) for DuckDB. - - Args: - sql: The SQL query string. - parameters: The parameters for the query (dict, tuple, list, or None). - - Returns: - A tuple containing the processed SQL string and the processed parameters. - """ - if not isinstance(parameters, dict) or not parameters: - # If parameters are not a dict, or empty dict, assume positional/no params - # Let the underlying driver handle tuples/lists directly - return sql, parameters - - # Convert named parameters to positional parameters - processed_sql = sql - processed_params: list[Any] = [] - for key, value in parameters.items(): - # Replace :key with ? in the SQL - processed_sql = processed_sql.replace(f":{key}", "?") - processed_params.append(value) - - return processed_sql, tuple(processed_params) - def execute_script( self, sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["DuckDBPyConnection"] = None, + **kwargs: Any, ) -> str: connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] return cast("str", getattr(cursor, "statusmessage", "DONE")) # pyright: ignore[reportUnknownMemberType] @@ -231,7 +230,9 @@ def select_arrow( # pyright: ignore[reportUnknownParameterType] sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[DuckDBPyConnection]" = None, + **kwargs: Any, ) -> "ArrowTable": # pyright: ignore[reportUnknownVariableType] """Execute a SQL query and return results as an Apache Arrow Table. @@ -239,9 +240,8 @@ def select_arrow( # pyright: ignore[reportUnknownParameterType] An Apache Arrow Table containing the query results. """ - conn = self._connection(connection) - processed_sql, processed_params = self._process_sql_params(sql, parameters) - - with self._with_cursor(conn) as cursor: - cursor.execute(processed_sql, processed_params) # pyright: ignore[reportUnknownMemberType] + connection = self._connection(connection) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) + with self._with_cursor(connection) as cursor: + cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] return cast("ArrowTable", cursor.fetch_arrow_table()) # pyright: ignore[reportUnknownMemberType] diff --git a/sqlspec/adapters/oracledb/driver.py b/sqlspec/adapters/oracledb/driver.py index 910c62e9f..bd3751f93 100644 --- a/sqlspec/adapters/oracledb/driver.py +++ b/sqlspec/adapters/oracledb/driver.py @@ -43,16 +43,25 @@ def select( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "list[Union[ModelDTOT, dict[str, Any]]]": """Fetch data from the database. + Args: + sql: The SQL query string. + parameters: The parameters for the query (dict, tuple, list, or None). + connection: Optional connection override. + schema_type: Optional schema class for the result. + **kwargs: Additional keyword arguments to merge with parameters if parameters is a dict. + Returns: List of row data as either model instances or dictionaries. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] results = cursor.fetchall() # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] @@ -71,16 +80,25 @@ def select_one( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Union[ModelDTOT, dict[str, Any]]": """Fetch one row from the database. + Args: + sql: The SQL query string. + parameters: The parameters for the query (dict, tuple, list, or None). + connection: Optional connection override. + schema_type: Optional schema class for the result. + **kwargs: Additional keyword arguments to merge with parameters if parameters is a dict. + Returns: The first row of the query results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] @@ -100,8 +118,10 @@ def select_one_or_none( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]": """Fetch one row from the database. @@ -109,7 +129,7 @@ def select_one_or_none( The first row of the query results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] @@ -131,8 +151,10 @@ def select_value( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, schema_type: "Optional[type[T]]" = None, + **kwargs: Any, ) -> "Union[T, Any]": """Fetch a single value from the database. @@ -140,7 +162,7 @@ def select_value( The first value from the first row of results, or None if no results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] @@ -156,8 +178,10 @@ def select_value_or_none( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, schema_type: "Optional[type[T]]" = None, + **kwargs: Any, ) -> "Optional[Union[T, Any]]": """Fetch a single value from the database. @@ -165,7 +189,7 @@ def select_value_or_none( The first value from the first row of results, or None if no results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] @@ -183,7 +207,9 @@ def insert_update_delete( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, + **kwargs: Any, ) -> int: """Insert, update, or delete data from the database. @@ -191,7 +217,7 @@ def insert_update_delete( Row count affected by the operation. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] @@ -202,8 +228,10 @@ def insert_update_delete_returning( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": """Insert, update, or delete data from the database and return result. @@ -211,7 +239,7 @@ def insert_update_delete_returning( The first row of results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] @@ -233,7 +261,9 @@ def execute_script( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, + **kwargs: Any, ) -> str: """Execute a script. @@ -241,7 +271,7 @@ def execute_script( Status message for the operation. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] @@ -252,7 +282,9 @@ def select_arrow( # pyright: ignore[reportUnknownParameterType] sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, + **kwargs: Any, ) -> "ArrowTable": # pyright: ignore[reportUnknownVariableType] """Execute a SQL query and return results as an Apache Arrow Table. @@ -261,7 +293,7 @@ def select_arrow( # pyright: ignore[reportUnknownParameterType] """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) results = connection.fetch_df_all(sql, parameters) return cast("ArrowTable", ArrowTable.from_arrays(arrays=results.column_arrays(), names=results.column_names())) # pyright: ignore @@ -290,8 +322,10 @@ async def select( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[AsyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "list[Union[ModelDTOT, dict[str, Any]]]": """Fetch data from the database. @@ -299,7 +333,7 @@ async def select( List of row data as either model instances or dictionaries. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] @@ -319,8 +353,10 @@ async def select_one( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[AsyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Union[ModelDTOT, dict[str, Any]]": """Fetch one row from the database. @@ -328,7 +364,7 @@ async def select_one( The first row of the query results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] @@ -347,8 +383,10 @@ async def select_one_or_none( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[AsyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]": """Fetch one row from the database. @@ -356,7 +394,7 @@ async def select_one_or_none( The first row of the query results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] @@ -378,8 +416,10 @@ async def select_value( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[AsyncConnection]" = None, schema_type: "Optional[type[T]]" = None, + **kwargs: Any, ) -> "Union[T, Any]": """Fetch a single value from the database. @@ -387,7 +427,7 @@ async def select_value( The first value from the first row of results, or None if no results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] @@ -403,8 +443,10 @@ async def select_value_or_none( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[AsyncConnection]" = None, schema_type: "Optional[type[T]]" = None, + **kwargs: Any, ) -> "Optional[Union[T, Any]]": """Fetch a single value from the database. @@ -412,7 +454,7 @@ async def select_value_or_none( The first value from the first row of results, or None if no results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] @@ -430,7 +472,9 @@ async def insert_update_delete( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[AsyncConnection]" = None, + **kwargs: Any, ) -> int: """Insert, update, or delete data from the database. @@ -438,7 +482,7 @@ async def insert_update_delete( Row count affected by the operation. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] @@ -449,8 +493,10 @@ async def insert_update_delete_returning( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[AsyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": """Insert, update, or delete data from the database and return result. @@ -458,7 +504,7 @@ async def insert_update_delete_returning( The first row of results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] @@ -480,7 +526,9 @@ async def execute_script( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[AsyncConnection]" = None, + **kwargs: Any, ) -> str: """Execute a script. @@ -488,49 +536,20 @@ async def execute_script( Status message for the operation. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] return str(cursor.rowcount) # pyright: ignore[reportUnknownMemberType] - async def execute_script_returning( - self, - sql: str, - parameters: "Optional[StatementParameterType]" = None, - /, - connection: "Optional[AsyncConnection]" = None, - schema_type: "Optional[type[ModelDTOT]]" = None, - ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": - """Execute a script and return result. - - Returns: - The first row of results. - """ - connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - - async with self._with_cursor(connection) as cursor: - await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] - result = await cursor.fetchone() # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] - - if result is None: - return None - - # Get column names - column_names = [col[0] for col in cursor.description or []] # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] - - if schema_type is not None: - return cast("ModelDTOT", schema_type(**dict(zip(column_names, result)))) # pyright: ignore[reportUnknownArgumentType] - # Always return dictionaries - return dict(zip(column_names, result)) # pyright: ignore[reportUnknownArgumentType,reportUnknownVariableType] - async def select_arrow( # pyright: ignore[reportUnknownParameterType] self, sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[AsyncConnection]" = None, + **kwargs: Any, ) -> "ArrowTable": # pyright: ignore[reportUnknownVariableType] """Execute a SQL query asynchronously and return results as an Apache Arrow Table. @@ -538,12 +557,13 @@ async def select_arrow( # pyright: ignore[reportUnknownParameterType] sql: The SQL query string. parameters: Parameters for the query. connection: Optional connection override. + **kwargs: Additional keyword arguments to merge with parameters if parameters is a dict. Returns: An Apache Arrow Table containing the query results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) results = await connection.fetch_df_all(sql, parameters) return ArrowTable.from_arrays(arrays=results.column_arrays(), names=results.column_names()) # pyright: ignore diff --git a/sqlspec/adapters/psycopg/driver.py b/sqlspec/adapters/psycopg/driver.py index bc7c4fad9..09375d970 100644 --- a/sqlspec/adapters/psycopg/driver.py +++ b/sqlspec/adapters/psycopg/driver.py @@ -3,7 +3,7 @@ from psycopg.rows import dict_row -from sqlspec.base import PARAM_REGEX, AsyncDriverAdapterProtocol, SyncDriverAdapterProtocol, T +from sqlspec.base import AsyncDriverAdapterProtocol, SyncDriverAdapterProtocol, T if TYPE_CHECKING: from collections.abc import AsyncGenerator, Generator @@ -33,75 +33,15 @@ def _with_cursor(connection: "Connection") -> "Generator[Any, None, None]": finally: cursor.close() - def _process_sql_params( - self, sql: str, parameters: "Optional[StatementParameterType]" = None - ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]": - """Process SQL query and parameters for DB-API execution. - - Converts named parameters (:name) to positional parameters (%s) - if the input parameters are a dictionary. - - Args: - sql: The SQL query string. - parameters: The parameters for the query (dict, tuple, list, or None). - - Returns: - A tuple containing the processed SQL string and the processed parameters - (always a tuple or None if the input was a dictionary, otherwise the original type). - - Raises: - ValueError: If a named parameter in the SQL is not found in the dictionary - or if a parameter in the dictionary is not used in the SQL. - """ - if not isinstance(parameters, dict) or not parameters: - # If parameters are not a dict, or empty dict, assume positional/no params - # Let the underlying driver handle tuples/lists directly - return sql, parameters - - processed_sql = "" - processed_params_list: list[Any] = [] - last_end = 0 - found_params: set[str] = set() - - for match in PARAM_REGEX.finditer(sql): - if match.group("dquote") is not None or match.group("squote") is not None: - # Skip placeholders within quotes - continue - - var_name = match.group("var_name") - if var_name is None: # Should not happen with the regex, but safeguard - continue - - if var_name not in parameters: - msg = f"Named parameter ':{var_name}' found in SQL but not provided in parameters dictionary." - raise ValueError(msg) - - # Append segment before the placeholder + the driver's positional placeholder - processed_sql += sql[last_end : match.start("var_name") - 1] + "%s" - processed_params_list.append(parameters[var_name]) - found_params.add(var_name) - last_end = match.end("var_name") - - # Append the rest of the SQL string - processed_sql += sql[last_end:] - - # Check if all provided parameters were used - unused_params = set(parameters.keys()) - found_params - if unused_params: - msg = f"Parameters provided but not found in SQL: {unused_params}" - # Depending on desired strictness, this could be a warning or an error - # For now, let's raise an error for clarity - raise ValueError(msg) - - return processed_sql, tuple(processed_params_list) - def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, /, - connection: "Optional[Connection]" = None, + *, schema_type: "Optional[type[ModelDTOT]]" = None, + connection: "Optional[Connection]" = None, + **kwargs: Any, ) -> "list[Union[ModelDTOT, dict[str, Any]]]": """Fetch data from the database. @@ -109,7 +49,7 @@ def select( List of row data as either model instances or dictionaries. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) results = cursor.fetchall() @@ -125,8 +65,10 @@ def select_one( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Union[ModelDTOT, dict[str, Any]]": """Fetch one row from the database. @@ -134,8 +76,7 @@ def select_one( The first row of the query results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) row = cursor.fetchone() @@ -149,8 +90,10 @@ def select_one_or_none( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]": """Fetch one row from the database. @@ -158,8 +101,7 @@ def select_one_or_none( The first row of the query results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) row = cursor.fetchone() @@ -174,8 +116,10 @@ def select_value( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, schema_type: "Optional[type[T]]" = None, + **kwargs: Any, ) -> "Union[T, Any]": """Fetch a single value from the database. @@ -183,8 +127,7 @@ def select_value( The first value from the first row of results, or None if no results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) row = cursor.fetchone() @@ -199,17 +142,18 @@ def select_value_or_none( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, schema_type: "Optional[type[T]]" = None, + **kwargs: Any, ) -> "Optional[Union[T, Any]]": - """Fetch a single value from the database. + """Fetch one row from the database. Returns: - The first value from the first row of results, or None if no results. + The first row of the query results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) row = cursor.fetchone() @@ -225,27 +169,30 @@ def insert_update_delete( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, + **kwargs: Any, ) -> int: - """Insert, update, or delete data from the database. + """Execute an INSERT, UPDATE, or DELETE query and return the number of affected rows. Returns: - Row count affected by the operation. + The number of rows affected by the operation. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) - return cursor.rowcount if hasattr(cursor, "rowcount") else -1 + return getattr(cursor, "rowcount", -1) # pyright: ignore[reportUnknownMemberType] def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": """Insert, update, or delete data from the database and return result. @@ -253,8 +200,7 @@ def insert_update_delete_returning( The first row of results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) result = cursor.fetchone() @@ -271,7 +217,9 @@ def execute_script( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, + **kwargs: Any, ) -> str: """Execute a script. @@ -279,39 +227,11 @@ def execute_script( Status message for the operation. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) return str(cursor.rowcount) - def execute_script_returning( - self, - sql: str, - parameters: "Optional[StatementParameterType]" = None, - /, - connection: "Optional[Connection]" = None, - schema_type: "Optional[type[ModelDTOT]]" = None, - ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": - """Execute a script and return result. - - Returns: - The first row of results. - """ - connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - - with self._with_cursor(connection) as cursor: - cursor.execute(sql, parameters) - result = cursor.fetchone() - - if result is None: - return None - - if schema_type is not None: - return cast("ModelDTOT", schema_type(**result)) # pyright: ignore[reportUnknownArgumentType] - return cast("dict[str, Any]", result) # pyright: ignore[reportUnknownArgumentType,reportUnknownVariableType] - class PsycopgAsyncDriver(AsyncDriverAdapterProtocol["AsyncConnection"]): """Psycopg Async Driver Adapter.""" @@ -331,75 +251,15 @@ async def _with_cursor(connection: "AsyncConnection") -> "AsyncGenerator[Any, No finally: await cursor.close() - def _process_sql_params( - self, sql: str, parameters: "Optional[StatementParameterType]" = None - ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]": - """Process SQL query and parameters for DB-API execution. - - Converts named parameters (:name) to positional parameters (%s) - if the input parameters are a dictionary. - - Args: - sql: The SQL query string. - parameters: The parameters for the query (dict, tuple, list, or None). - - Returns: - A tuple containing the processed SQL string and the processed parameters - (always a tuple or None if the input was a dictionary, otherwise the original type). - - Raises: - ValueError: If a named parameter in the SQL is not found in the dictionary - or if a parameter in the dictionary is not used in the SQL. - """ - if not isinstance(parameters, dict) or not parameters: - # If parameters are not a dict, or empty dict, assume positional/no params - # Let the underlying driver handle tuples/lists directly - return sql, parameters - - processed_sql = "" - processed_params_list: list[Any] = [] - last_end = 0 - found_params: set[str] = set() - - for match in PARAM_REGEX.finditer(sql): - if match.group("dquote") is not None or match.group("squote") is not None: - # Skip placeholders within quotes - continue - - var_name = match.group("var_name") - if var_name is None: # Should not happen with the regex, but safeguard - continue - - if var_name not in parameters: - msg = f"Named parameter ':{var_name}' found in SQL but not provided in parameters dictionary." - raise ValueError(msg) - - # Append segment before the placeholder + the driver's positional placeholder - processed_sql += sql[last_end : match.start("var_name") - 1] + "%s" - processed_params_list.append(parameters[var_name]) - found_params.add(var_name) - last_end = match.end("var_name") - - # Append the rest of the SQL string - processed_sql += sql[last_end:] - - # Check if all provided parameters were used - unused_params = set(parameters.keys()) - found_params - if unused_params: - msg = f"Parameters provided but not found in SQL: {unused_params}" - # Depending on desired strictness, this could be a warning or an error - # For now, let's raise an error for clarity - raise ValueError(msg) - - return processed_sql, tuple(processed_params_list) - async def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[AsyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "list[Union[ModelDTOT, dict[str, Any]]]": """Fetch data from the database. @@ -407,7 +267,7 @@ async def select( List of row data as either model instances or dictionaries. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) results: list[Union[ModelDTOT, dict[str, Any]]] = [] async with self._with_cursor(connection) as cursor: @@ -424,8 +284,10 @@ async def select_one( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[AsyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Union[ModelDTOT, dict[str, Any]]": """Fetch one row from the database. @@ -433,7 +295,7 @@ async def select_one( The first row of the query results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) @@ -448,8 +310,10 @@ async def select_one_or_none( sql: str, parameters: "Optional[StatementParameterType]" = None, /, - connection: "Optional[AsyncConnection]" = None, + *, schema_type: "Optional[type[ModelDTOT]]" = None, + connection: "Optional[AsyncConnection]" = None, + **kwargs: Any, ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]": """Fetch one row from the database. @@ -457,7 +321,7 @@ async def select_one_or_none( The first row of the query results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) @@ -473,16 +337,18 @@ async def select_value( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[AsyncConnection]" = None, schema_type: "Optional[type[T]]" = None, - ) -> "Optional[Union[T, Any]]": + **kwargs: Any, + ) -> "Union[T, Any]": """Fetch a single value from the database. Returns: The first value from the first row of results, or None if no results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) @@ -498,8 +364,10 @@ async def select_value_or_none( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[AsyncConnection]" = None, schema_type: "Optional[type[T]]" = None, + **kwargs: Any, ) -> "Optional[Union[T, Any]]": """Fetch a single value from the database. @@ -507,7 +375,7 @@ async def select_value_or_none( The first value from the first row of results, or None if no results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) @@ -524,15 +392,17 @@ async def insert_update_delete( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[AsyncConnection]" = None, + **kwargs: Any, ) -> int: - """Insert, update, or delete data from the database. + """Execute an INSERT, UPDATE, or DELETE query and return the number of affected rows. Returns: - Row count affected by the operation. + The number of rows affected by the operation. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) @@ -547,8 +417,10 @@ async def insert_update_delete_returning( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[AsyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": """Insert, update, or delete data from the database and return result. @@ -556,7 +428,7 @@ async def insert_update_delete_returning( The first row of results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) @@ -574,7 +446,9 @@ async def execute_script( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[AsyncConnection]" = None, + **kwargs: Any, ) -> str: """Execute a script. @@ -582,35 +456,8 @@ async def execute_script( Status message for the operation. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) return str(cursor.rowcount) - - async def execute_script_returning( - self, - sql: str, - parameters: "Optional[StatementParameterType]" = None, - /, - connection: "Optional[AsyncConnection]" = None, - schema_type: "Optional[type[ModelDTOT]]" = None, - ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": - """Execute a script and return result. - - Returns: - The first row of results. - """ - connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - - async with self._with_cursor(connection) as cursor: - await cursor.execute(sql, parameters) - result = await cursor.fetchone() - - if result is None: - return None - - if schema_type is not None: - return cast("ModelDTOT", schema_type(**result)) # pyright: ignore[reportUnknownArgumentType] - return cast("dict[str, Any]", result) # pyright: ignore[reportUnknownArgumentType,reportUnknownVariableType] diff --git a/sqlspec/adapters/sqlite/driver.py b/sqlspec/adapters/sqlite/driver.py index 08e81f0a2..1e8b20bde 100644 --- a/sqlspec/adapters/sqlite/driver.py +++ b/sqlspec/adapters/sqlite/driver.py @@ -37,8 +37,10 @@ def select( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "list[Union[ModelDTOT, dict[str, Any]]]": """Fetch data from the database. @@ -46,7 +48,7 @@ def select( List of row data as either model instances or dictionaries. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: if not parameters: cursor.execute(sql) # pyright: ignore[reportUnknownMemberType] @@ -65,8 +67,10 @@ def select_one( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Union[ModelDTOT, dict[str, Any]]": """Fetch one row from the database. @@ -74,7 +78,7 @@ def select_one( The first row of the query results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: if not parameters: cursor.execute(sql) # pyright: ignore[reportUnknownMemberType] @@ -92,8 +96,10 @@ def select_one_or_none( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]": """Fetch one row from the database. @@ -101,7 +107,7 @@ def select_one_or_none( The first row of the query results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: if not parameters: cursor.execute(sql) # pyright: ignore[reportUnknownMemberType] @@ -120,8 +126,10 @@ def select_value( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, schema_type: "Optional[type[T]]" = None, + **kwargs: Any, ) -> "Union[T, Any]": """Fetch a single value from the database. @@ -129,7 +137,7 @@ def select_value( The first value from the first row of results, or None if no results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: if not parameters: cursor.execute(sql) # pyright: ignore[reportUnknownMemberType] @@ -146,8 +154,10 @@ def select_value_or_none( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[Connection]" = None, schema_type: "Optional[type[T]]" = None, + **kwargs: Any, ) -> "Optional[Union[T, Any]]": """Fetch a single value from the database. @@ -155,7 +165,7 @@ def select_value_or_none( The first value from the first row of results, or None if no results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: if not parameters: cursor.execute(sql) # pyright: ignore[reportUnknownMemberType] @@ -173,7 +183,9 @@ def insert_update_delete( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, + **kwargs: Any, ) -> int: """Insert, update, or delete data from the database. @@ -181,7 +193,7 @@ def insert_update_delete( Row count affected by the operation. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: if not parameters: @@ -195,8 +207,10 @@ def insert_update_delete_returning( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": """Insert, update, or delete data from the database and return result. @@ -204,7 +218,7 @@ def insert_update_delete_returning( The first row of results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: if not parameters: @@ -219,41 +233,14 @@ def insert_update_delete_returning( return cast("ModelDTOT", schema_type(**dict(zip(column_names, result[0])))) return dict(zip(column_names, result[0])) - def _process_sql_params( - self, sql: str, parameters: "Optional[StatementParameterType]" = None - ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]": - """Process SQL query and parameters for DB-API execution. - - Converts named parameters (:name) to positional parameters (?) for SQLite. - - Args: - sql: The SQL query string. - parameters: The parameters for the query (dict, tuple, list, or None). - - Returns: - A tuple containing the processed SQL string and the processed parameters. - """ - if not isinstance(parameters, dict) or not parameters: - # If parameters are not a dict, or empty dict, assume positional/no params - # Let the underlying driver handle tuples/lists directly - return sql, parameters - - # Convert named parameters to positional parameters - processed_sql = sql - processed_params: list[Any] = [] - for key, value in parameters.items(): - # Replace :key with ? in the SQL - processed_sql = processed_sql.replace(f":{key}", "?") - processed_params.append(value) - - return processed_sql, tuple(processed_params) - def execute_script( self, sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, + **kwargs: Any, ) -> str: """Execute a script. @@ -261,6 +248,7 @@ def execute_script( Status message for the operation. """ connection = self._connection(connection) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) # For DDL statements, don't pass parameters to execute # SQLite doesn't support parameters for DDL statements @@ -268,8 +256,7 @@ def execute_script( if not parameters: cursor.execute(sql) # pyright: ignore[reportUnknownMemberType] else: - sql, parameters = self._process_sql_params(sql, parameters) - cursor.execute(sql, parameters) # type: ignore[arg-type] + cursor.execute(sql, parameters) return cast("str", cursor.statusmessage) if hasattr(cursor, "statusmessage") else "DONE" # pyright: ignore[reportUnknownMemberType,reportAttributeAccessIssue] @@ -278,8 +265,10 @@ def execute_script_returning( sql: str, parameters: Optional["StatementParameterType"] = None, /, + *, connection: Optional["Connection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": """Execute a script and return result. @@ -287,7 +276,7 @@ def execute_script_returning( The first row of results. """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: if not parameters: diff --git a/sqlspec/base.py b/sqlspec/base.py index 4e727950a..798fa27ec 100644 --- a/sqlspec/base.py +++ b/sqlspec/base.py @@ -1,4 +1,4 @@ -# ruff: noqa: PLR6301 +# ruff: noqa: PLR6301, PLR0912, PLR0915, C901, PLR0911 import re from abc import ABC, abstractmethod from collections.abc import Awaitable @@ -16,7 +16,10 @@ overload, ) -from sqlspec.exceptions import NotFoundError +import sqlglot +from sqlglot import exp + +from sqlspec.exceptions import NotFoundError, SQLParsingError from sqlspec.typing import ModelDTOT, StatementParameterType if TYPE_CHECKING: @@ -456,70 +459,169 @@ def _process_sql_statement(self, sql: str) -> str: return sql def _process_sql_params( - self, sql: str, parameters: "Optional[StatementParameterType]" = None + self, sql: str, parameters: "Optional[StatementParameterType]" = None, /, **kwargs: Any ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]": """Process SQL query and parameters for DB-API execution. - Converts named parameters (:name) to positional parameters specified by `self.param_style` - if the input parameters are a dictionary. + Uses sqlglot to parse named parameters (:name) if parameters is a dictionary, + and converts them to the driver's `param_style`. + Handles single value parameters by wrapping them in a tuple. Args: sql: The SQL query string. - parameters: The parameters for the query (dict, tuple, list, or None). + parameters: The parameters for the query (dict, tuple, list, single value, or None). + **kwargs: Additional keyword arguments to merge with parameters if parameters is a dict. Returns: A tuple containing the processed SQL string and the processed parameters - (always a tuple or None if the input was a dictionary, otherwise the original type). + (tuple for named/single params, original list/tuple for positional, None if no params). Raises: - ValueError: If a named parameter in the SQL is not found in the dictionary - or if a parameter in the dictionary is not used in the SQL. + ValueError: If parameter validation fails (missing/extra keys for dicts, + mixing named/positional placeholders with dicts). + ImportError: If sqlglot is not installed. """ - if not isinstance(parameters, dict) or not parameters: - # If parameters are not a dict, or empty dict, assume positional/no params - # Let the underlying driver handle tuples/lists directly + # 1. Handle None and kwargs + if parameters is None and not kwargs: + return self._process_sql_statement(sql), None + + # 2. Merge parameters with kwargs if parameters is a dict + parameters = {**parameters, **kwargs} if isinstance(parameters, dict) else kwargs if kwargs else parameters + + # 3. Handle dictionary parameters using sqlglot + if isinstance(parameters, dict): + if not parameters: + # Return early for empty dict + return self._process_sql_statement(sql), parameters + + # First check if there are any :param style placeholders using regex + regex_placeholders = [] + for match in PARAM_REGEX.finditer(sql): + if match.group("dquote") is not None or match.group("squote") is not None: + continue + var_name = match.group("var_name") + if var_name is not None: + regex_placeholders.append(var_name) + + try: + expression = sqlglot.parse_one(sql) + except Exception as e: + # If sqlglot parsing fails but regex found placeholders, use regex approach + if regex_placeholders: + # Use regex approach as fallback + processed_sql = sql + param_values = [] + for key, value in parameters.items(): + if key in regex_placeholders: + processed_sql = processed_sql.replace(f":{key}", self.param_style) + param_values.append(value) + + # Validate that all placeholders were found + if len(param_values) != len(regex_placeholders): + msg = f"Not all placeholders found in parameters: {set(regex_placeholders) - set(parameters.keys())}" + raise SQLParsingError(msg) from e + + return self._process_sql_statement(processed_sql), tuple(param_values) + + msg = f"sqlglot failed to parse SQL: {e}" + raise SQLParsingError(msg) from e + + placeholders = list(expression.find_all(exp.Parameter)) + placeholder_names: list[str] = [] + has_unnamed = False + for p in placeholders: + if p.name: + placeholder_names.append(p.name) + else: + has_unnamed = True # Found unnamed placeholder like '?' + + # If sqlglot didn't find any placeholders but regex did, use regex approach + if not placeholder_names and regex_placeholders: + processed_sql = sql + param_values = [] + for key, value in parameters.items(): + if key in regex_placeholders: + processed_sql = processed_sql.replace(f":{key}", self.param_style) + param_values.append(value) + + # Validate that all placeholders were found + if len(param_values) != len(regex_placeholders): + msg = ( + f"Not all placeholders found in parameters: {set(regex_placeholders) - set(parameters.keys())}" + ) + raise SQLParsingError(msg) + + return self._process_sql_statement(processed_sql), tuple(param_values) + + if has_unnamed: + msg = "Cannot use dictionary parameters with unnamed placeholders (e.g., '?') in the SQL query." + raise SQLParsingError(msg) + + if not placeholder_names: + # If no named placeholders found, but dict was provided, raise error. + # (We already handled the empty dict case above) + msg = "Dictionary parameters provided, but no named placeholders found in the SQL query." + raise SQLParsingError(msg) + + # Validation + provided_keys = set(parameters.keys()) + required_keys = set(placeholder_names) + + missing_keys = required_keys - provided_keys + if missing_keys: + msg = f"Named parameters found in SQL but not provided in parameters dictionary: {missing_keys}" + raise SQLParsingError(msg) + + extra_keys = provided_keys - required_keys + if extra_keys: + msg = f"Parameters provided but not found in SQL: {extra_keys}" + raise SQLParsingError(msg) # Strict check + + # Build ordered tuple of parameters + ordered_params = tuple(parameters[name] for name in placeholder_names) + + # Replace :name with self.param_style using regex for safety + processed_sql = "" + last_end = 0 + params_iter = iter(placeholder_names) # Ensure order correctness during replacement + + for match in PARAM_REGEX.finditer(sql): + if match.group("dquote") is not None or match.group("squote") is not None: + processed_sql += sql[last_end : match.end()] + last_end = match.end() + continue + + var_name = match.group("var_name") + if var_name is None: + processed_sql += sql[last_end : match.end()] + last_end = match.end() + continue + + expected_param = next(params_iter, None) + if var_name != expected_param: + msg = f"Internal parameter processing mismatch: Regex found ':{var_name}' but expected ':{expected_param}' based on sqlglot parse order." + raise SQLParsingError(msg) + + # Replace :param with param_style + start_replace = match.start("var_name") - 1 # Include the ':' + processed_sql += sql[last_end:start_replace] + self.param_style + last_end = match.end("var_name") + + processed_sql += sql[last_end:] # Append remaining part + + final_sql = self._process_sql_statement(processed_sql) + return final_sql, ordered_params + + # 4. Handle list/tuple parameters (positional) + if isinstance(parameters, (list, tuple)): + # Let the underlying driver handle these directly return self._process_sql_statement(sql), parameters - processed_sql = "" - processed_params_list: list[Any] = [] - last_end = 0 - found_params: set[str] = set() - - for match in PARAM_REGEX.finditer(sql): - if match.group("dquote") is not None or match.group("squote") is not None: - # Skip placeholders within quotes - continue - - var_name = match.group("var_name") - if var_name is None: # Should not happen with the regex, but safeguard - continue - - if var_name not in parameters: - msg = f"Named parameter ':{var_name}' found in SQL but not provided in parameters dictionary." - raise ValueError(msg) - - # Append segment before the placeholder + the leading character + the driver's positional placeholder - # The match.start("var_name") -1 includes the character before the ':' - processed_sql += sql[last_end : match.start("var_name")] + self.param_style - processed_params_list.append(parameters[var_name]) - found_params.add(var_name) - last_end = match.end("var_name") - - # Append the rest of the SQL string - processed_sql += sql[last_end:] - - # Check if all provided parameters were used - unused_params = set(parameters.keys()) - found_params - if unused_params: - msg = f"Parameters provided but not found in SQL: {unused_params}" - # Depending on desired strictness, this could be a warning or an error - # For now, let's raise an error for clarity - raise ValueError(msg) - - processed_params = tuple(processed_params_list) - # Pass the processed SQL through the driver-specific processor if needed - final_sql = self._process_sql_statement(processed_sql) - return final_sql, processed_params + # 5. Handle single value parameters + # If it wasn't None, dict, list, or tuple, it must be a single value + processed_params: tuple[Any, ...] = (parameters,) + # Assuming single value maps to a single positional placeholder. + return self._process_sql_statement(sql), processed_params class SyncArrowBulkOperationsMixin(Generic[ConnectionT]): @@ -536,7 +638,9 @@ def select_arrow( # pyright: ignore[reportUnknownParameterType] sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[ConnectionT]" = None, + **kwargs: Any, ) -> "ArrowTable": # pyright: ignore[reportUnknownReturnType] """Execute a SQL query and return results as an Apache Arrow Table. @@ -544,6 +648,7 @@ def select_arrow( # pyright: ignore[reportUnknownParameterType] sql: The SQL query string. parameters: Parameters for the query. connection: Optional connection override. + **kwargs: Additional keyword arguments to merge with parameters if parameters is a dict. Returns: An Apache Arrow Table containing the query results. @@ -563,8 +668,10 @@ def select( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[ConnectionT]" = None, schema_type: Optional[type[ModelDTOT]] = None, + **kwargs: Any, ) -> "list[Union[ModelDTOT, dict[str, Any]]]": ... @abstractmethod @@ -573,8 +680,10 @@ def select_one( sql: str, parameters: Optional[StatementParameterType] = None, /, + *, connection: Optional[ConnectionT] = None, schema_type: Optional[type[ModelDTOT]] = None, + **kwargs: Any, ) -> "Union[ModelDTOT, dict[str, Any]]": ... @abstractmethod @@ -583,8 +692,10 @@ def select_one_or_none( sql: str, parameters: Optional[StatementParameterType] = None, /, + *, connection: Optional[ConnectionT] = None, schema_type: Optional[type[ModelDTOT]] = None, + **kwargs: Any, ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]": ... @abstractmethod @@ -593,8 +704,10 @@ def select_value( sql: str, parameters: Optional[StatementParameterType] = None, /, + *, connection: Optional[ConnectionT] = None, schema_type: Optional[type[T]] = None, + **kwargs: Any, ) -> "Union[Any, T]": ... @abstractmethod @@ -603,8 +716,10 @@ def select_value_or_none( sql: str, parameters: Optional[StatementParameterType] = None, /, + *, connection: Optional[ConnectionT] = None, schema_type: Optional[type[T]] = None, + **kwargs: Any, ) -> "Optional[Union[Any, T]]": ... @abstractmethod @@ -613,7 +728,9 @@ def insert_update_delete( sql: str, parameters: Optional[StatementParameterType] = None, /, + *, connection: Optional[ConnectionT] = None, + **kwargs: Any, ) -> int: ... @abstractmethod @@ -622,8 +739,10 @@ def insert_update_delete_returning( sql: str, parameters: Optional[StatementParameterType] = None, /, + *, connection: Optional[ConnectionT] = None, schema_type: Optional[type[ModelDTOT]] = None, + **kwargs: Any, ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": ... @abstractmethod @@ -632,7 +751,9 @@ def execute_script( sql: str, parameters: Optional[StatementParameterType] = None, /, + *, connection: Optional[ConnectionT] = None, + **kwargs: Any, ) -> str: ... @@ -647,7 +768,9 @@ async def select_arrow( # pyright: ignore[reportUnknownParameterType] sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[ConnectionT]" = None, + **kwargs: Any, ) -> "ArrowTable": # pyright: ignore[reportUnknownReturnType] """Execute a SQL query and return results as an Apache Arrow Table. @@ -655,6 +778,7 @@ async def select_arrow( # pyright: ignore[reportUnknownParameterType] sql: The SQL query string. parameters: Parameters for the query. connection: Optional connection override. + **kwargs: Additional keyword arguments to merge with parameters if parameters is a dict. Returns: An Apache Arrow Table containing the query results. @@ -674,8 +798,10 @@ async def select( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[ConnectionT]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "list[Union[ModelDTOT, dict[str, Any]]]": ... @abstractmethod @@ -684,8 +810,10 @@ async def select_one( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[ConnectionT]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Union[ModelDTOT, dict[str, Any]]": ... @abstractmethod @@ -694,8 +822,10 @@ async def select_one_or_none( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[ConnectionT]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]": ... @abstractmethod @@ -704,8 +834,10 @@ async def select_value( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[ConnectionT]" = None, schema_type: "Optional[type[T]]" = None, + **kwargs: Any, ) -> "Union[Any, T]": ... @abstractmethod @@ -714,8 +846,10 @@ async def select_value_or_none( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[ConnectionT]" = None, schema_type: "Optional[type[T]]" = None, + **kwargs: Any, ) -> "Optional[Union[Any, T]]": ... @abstractmethod @@ -724,7 +858,9 @@ async def insert_update_delete( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[ConnectionT]" = None, + **kwargs: Any, ) -> int: ... @abstractmethod @@ -733,8 +869,10 @@ async def insert_update_delete_returning( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[ConnectionT]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": ... @abstractmethod @@ -743,7 +881,9 @@ async def execute_script( sql: str, parameters: "Optional[StatementParameterType]" = None, /, + *, connection: "Optional[ConnectionT]" = None, + **kwargs: Any, ) -> str: ... From 2e6417941adbe76411f82eb92666e4f337ef50c2 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sun, 20 Apr 2025 03:03:15 +0000 Subject: [PATCH 2/2] fix: more --- sqlspec/adapters/adbc/driver.py | 153 ++++++- sqlspec/adapters/aiosqlite/driver.py | 1 + sqlspec/adapters/asyncmy/driver.py | 1 + sqlspec/adapters/asyncpg/driver.py | 215 +++++++++- sqlspec/adapters/duckdb/driver.py | 36 +- sqlspec/adapters/oracledb/driver.py | 2 + sqlspec/adapters/psycopg/driver.py | 135 ++++++- sqlspec/adapters/sqlite/driver.py | 23 +- sqlspec/base.py | 205 ++-------- sqlspec/exceptions.py | 30 ++ sqlspec/statement.py | 373 ++++++++++++++++++ sqlspec/typing.py | 2 +- tests/fixtures/sql_utils.py | 6 +- .../test_adbc/test_driver_duckdb.py | 166 ++++++++ .../test_adbc/test_driver_sqlite.py | 58 +++ .../test_adapters/test_asyncpg/__init__.py | 0 .../test_asyncpg/test_connection.py | 42 ++ .../test_adapters/test_asyncpg/test_driver.py | 275 +++++++++++++ 18 files changed, 1452 insertions(+), 271 deletions(-) create mode 100644 sqlspec/statement.py create mode 100644 tests/integration/test_adapters/test_asyncpg/__init__.py create mode 100644 tests/integration/test_adapters/test_asyncpg/test_connection.py create mode 100644 tests/integration/test_adapters/test_asyncpg/test_driver.py diff --git a/sqlspec/adapters/adbc/driver.py b/sqlspec/adapters/adbc/driver.py index 9123926f7..fc9e5c4d9 100644 --- a/sqlspec/adapters/adbc/driver.py +++ b/sqlspec/adapters/adbc/driver.py @@ -1,30 +1,35 @@ import contextlib +import logging import re from collections.abc import Generator from contextlib import contextmanager from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, cast -from adbc_driver_manager.dbapi import Connection -from adbc_driver_manager.dbapi import Cursor as DbapiCursor +from adbc_driver_manager.dbapi import Connection, Cursor -from sqlspec._typing import ArrowTable from sqlspec.base import SyncArrowBulkOperationsMixin, SyncDriverAdapterProtocol, T +from sqlspec.exceptions import ParameterStyleMismatchError, SQLParsingError +from sqlspec.statement import SQLStatement +from sqlspec.typing import ArrowTable, StatementParameterType if TYPE_CHECKING: from sqlspec.typing import ArrowTable, ModelDTOT, StatementParameterType __all__ = ("AdbcDriver",) +logger = logging.getLogger("sqlspec") + -# Regex to find :param or %(param)s style placeholders, skipping those inside quotes PARAM_REGEX = re.compile( - r""" - (?P"([^"]|\\")*") | # Double-quoted strings - (?P'([^']|\\')*') | # Single-quoted strings - : (?P[a-zA-Z_][a-zA-Z0-9_]*) | # :var_name - % \( (?P[a-zA-Z_][a-zA-Z0-9_]*) \) s # %(var_name)s + r"""(?"(?:[^"]|"")*") | # Double-quoted strings + (?P'(?:[^']|'')*') | # Single-quoted strings + (?P--.*?\n|\/\*.*?\*\/) | # SQL comments + (?P[:\$])(?P[a-zA-Z_][a-zA-Z0-9_]*) # :name or $name identifier + ) """, - re.VERBOSE, + re.VERBOSE | re.DOTALL, ) @@ -37,22 +42,124 @@ class AdbcDriver(SyncArrowBulkOperationsMixin["Connection"], SyncDriverAdapterPr def __init__(self, connection: "Connection") -> None: """Initialize the ADBC driver adapter.""" self.connection = connection - # Potentially introspect connection.paramstyle here if needed in the future - # For now, assume 'qmark' based on typical ADBC DBAPI behavior + self.dialect = self._get_dialect(connection) + + @staticmethod + def _get_dialect(connection: "Connection") -> str: # noqa: PLR0911 + """Get the database dialect based on the driver name. + + Args: + connection: The ADBC connection object. + + Returns: + The database dialect. + """ + driver_name = connection.adbc_get_info()["vendor_name"].lower() + if "postgres" in driver_name: + return "postgres" + if "bigquery" in driver_name: + return "bigquery" + if "sqlite" in driver_name: + return "sqlite" + if "duckdb" in driver_name: + return "duckdb" + if "mysql" in driver_name: + return "mysql" + if "snowflake" in driver_name: + return "snowflake" + return "postgres" # default to postgresql dialect @staticmethod - def _cursor(connection: "Connection", *args: Any, **kwargs: Any) -> "DbapiCursor": + def _cursor(connection: "Connection", *args: Any, **kwargs: Any) -> "Cursor": return connection.cursor(*args, **kwargs) @contextmanager - def _with_cursor(self, connection: "Connection") -> Generator["DbapiCursor", None, None]: - cursor: DbapiCursor = self._cursor(connection) + def _with_cursor(self, connection: "Connection") -> Generator["Cursor", None, None]: + cursor = self._cursor(connection) try: yield cursor finally: with contextlib.suppress(Exception): cursor.close() # type: ignore[no-untyped-call] + def _process_sql_params( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + **kwargs: Any, + ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]": + # Determine effective parameter type *before* calling SQLStatement + merged_params_type = dict if kwargs else type(parameters) + + # If ADBC + sqlite/duckdb + dictionary params, handle conversion manually + if self.dialect in {"sqlite", "duckdb"} and merged_params_type is dict: + logger.debug( + "ADBC/%s with dict params; bypassing SQLStatement conversion, manually converting to '?' positional.", + self.dialect, + ) + + # Combine parameters and kwargs into the actual dictionary to use + parameter_dict = {} # type: ignore[var-annotated] + if isinstance(parameters, dict): + parameter_dict.update(parameters) + if kwargs: + parameter_dict.update(kwargs) + + # Define regex locally to find :name or $name + + processed_sql_parts: list[str] = [] + ordered_params = [] + last_end = 0 + found_params_regex: list[str] = [] + + for match in PARAM_REGEX.finditer(sql): # Use original sql + if match.group("dquote") or match.group("squote") or match.group("comment"): + continue + + if match.group("var_name"): + var_name = match.group("var_name") + leading_char = match.group("lead") # : or $ + found_params_regex.append(var_name) + # Use match span directly for replacement + start = match.start() + end = match.end() + + if var_name not in parameter_dict: + msg = f"Named parameter '{leading_char}{var_name}' found in SQL but not provided. SQL: {sql}" + raise SQLParsingError(msg) + + processed_sql_parts.extend((sql[last_end:start], "?")) # Force ? style + ordered_params.append(parameter_dict[var_name]) + last_end = end + + processed_sql_parts.append(sql[last_end:]) + + if not found_params_regex and parameter_dict: + msg = f"ADBC/{self.dialect}: Dict params provided, but no :name or $name placeholders found. SQL: {sql}" + raise ParameterStyleMismatchError(msg) + + # Key validation + provided_keys = set(parameter_dict.keys()) + missing_keys = set(found_params_regex) - provided_keys + if missing_keys: + msg = ( + f"Named parameters found in SQL ({found_params_regex}) but not provided: {missing_keys}. SQL: {sql}" + ) + raise SQLParsingError(msg) + extra_keys = provided_keys - set(found_params_regex) + if extra_keys: + logger.debug("Extra parameters provided for ADBC/%s: %s", self.dialect, extra_keys) + # Allow extra keys + + final_sql = "".join(processed_sql_parts) + final_params = tuple(ordered_params) + return final_sql, final_params + # For all other cases (other dialects, or non-dict params for sqlite/duckdb), + # use the standard SQLStatement processing. + stmt = SQLStatement(sql=sql, parameters=parameters, dialect=self.dialect, kwargs=kwargs or None) + return stmt.process() + def select( self, sql: str, @@ -224,17 +331,21 @@ def insert_update_delete_returning( """ connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters, **kwargs) - column_names: list[str] = [] - with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] result = cursor.fetchall() # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] - if len(result) == 0: # pyright: ignore[reportUnknownArgumentType] + if not result: return None + + first_row = result[0] + column_names = [c[0] for c in cursor.description or []] # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] - if schema_type is not None: - return cast("ModelDTOT", schema_type(**dict(zip(column_names, result[0])))) # pyright: ignore[reportUnknownArgumentType] - return dict(zip(column_names, result[0])) # pyright: ignore[reportUnknownVariableType,reportUnknownArgumentType] + + result_dict = dict(zip(column_names, first_row)) + + if schema_type is None: + return result_dict + return cast("ModelDTOT", schema_type(**result_dict)) def execute_script( self, diff --git a/sqlspec/adapters/aiosqlite/driver.py b/sqlspec/adapters/aiosqlite/driver.py index dbfea15fa..b61ee3b2e 100644 --- a/sqlspec/adapters/aiosqlite/driver.py +++ b/sqlspec/adapters/aiosqlite/driver.py @@ -17,6 +17,7 @@ class AiosqliteDriver(AsyncDriverAdapterProtocol["Connection"]): """SQLite Async Driver Adapter.""" connection: "Connection" + dialect: str = "sqlite" def __init__(self, connection: "Connection") -> None: self.connection = connection diff --git a/sqlspec/adapters/asyncmy/driver.py b/sqlspec/adapters/asyncmy/driver.py index 2d9e06f98..2ffa02be2 100644 --- a/sqlspec/adapters/asyncmy/driver.py +++ b/sqlspec/adapters/asyncmy/driver.py @@ -18,6 +18,7 @@ class AsyncmyDriver(AsyncDriverAdapterProtocol["Connection"]): """Asyncmy MySQL/MariaDB Driver Adapter.""" connection: "Connection" + dialect: str = "mysql" def __init__(self, connection: "Connection") -> None: self.connection = connection diff --git a/sqlspec/adapters/asyncpg/driver.py b/sqlspec/adapters/asyncpg/driver.py index edf3fc9ff..28fa4b63d 100644 --- a/sqlspec/adapters/asyncpg/driver.py +++ b/sqlspec/adapters/asyncpg/driver.py @@ -1,9 +1,13 @@ +import logging +import re from typing import TYPE_CHECKING, Any, Optional, Union, cast from asyncpg import Connection from typing_extensions import TypeAlias from sqlspec.base import AsyncDriverAdapterProtocol, T +from sqlspec.exceptions import SQLParsingError +from sqlspec.statement import PARAM_REGEX, SQLStatement if TYPE_CHECKING: from asyncpg.connection import Connection @@ -13,6 +17,18 @@ __all__ = ("AsyncpgConnection", "AsyncpgDriver") +logger = logging.getLogger("sqlspec") + +# Regex to find '?' placeholders, skipping those inside quotes or SQL comments +# Simplified version, assumes standard SQL quoting/comments +QMARK_REGEX = re.compile( + r"""(?P"[^"]*") | # Double-quoted strings + (?P\'[^\']*\') | # Single-quoted strings + (?P--[^\n]*|/\*.*?\*/) | # SQL comments (single/multi-line) + (?P\?) # The question mark placeholder + """, + re.VERBOSE | re.DOTALL, +) AsyncpgConnection: TypeAlias = "Union[Connection[Any], PoolConnectionProxy[Any]]" # pyright: ignore[reportMissingTypeArgument] @@ -21,10 +37,165 @@ class AsyncpgDriver(AsyncDriverAdapterProtocol["AsyncpgConnection"]): """AsyncPG Postgres Driver Adapter.""" connection: "AsyncpgConnection" + dialect: str = "postgres" def __init__(self, connection: "AsyncpgConnection") -> None: self.connection = connection + def _process_sql_params( # noqa: C901, PLR0912, PLR0915 + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + **kwargs: Any, + ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]": + # Use SQLStatement for parameter validation and merging first + # It also handles potential dialect-specific logic if implemented there. + stmt = SQLStatement(sql=sql, parameters=parameters, dialect=self.dialect, kwargs=kwargs or None) + sql, parameters = stmt.process() + + # Case 1: Parameters are effectively a dictionary (either passed as dict or via kwargs merged by SQLStatement) + if isinstance(parameters, dict): + processed_sql_parts: list[str] = [] + ordered_params = [] + last_end = 0 + param_index = 1 + found_params_regex: list[str] = [] + + # Manually parse the PROCESSED SQL for :name -> $n conversion + for match in PARAM_REGEX.finditer(sql): + # Skip matches inside quotes or comments + if match.group("dquote") or match.group("squote") or match.group("comment"): + continue + + if match.group("var_name"): # Finds :var_name + var_name = match.group("var_name") + found_params_regex.append(var_name) + start = match.start("var_name") - 1 # Include the ':' + end = match.end("var_name") + + # SQLStatement should have already validated parameter existence, + # but we double-check here during ordering. + if var_name not in parameters: + # This should ideally not happen if SQLStatement validation is robust. + msg = ( + f"Named parameter ':{var_name}' found in SQL but missing from processed parameters. " + f"Processed SQL: {sql}" + ) + raise SQLParsingError(msg) + + processed_sql_parts.extend((sql[last_end:start], f"${param_index}")) + ordered_params.append(parameters[var_name]) + last_end = end + param_index += 1 + + processed_sql_parts.append(sql[last_end:]) + final_sql = "".join(processed_sql_parts) + + # --- Validation --- + # Check if named placeholders were found if dict params were provided + # SQLStatement might handle this validation, but a warning here can be useful. + if not found_params_regex and parameters: + logger.warning( + "Dict params provided (%s), but no :name placeholders found. SQL: %s", + list(parameters.keys()), + sql, + ) + # If no placeholders, return original SQL from SQLStatement and empty tuple for asyncpg + return sql, () + + # Additional checks (potentially redundant if SQLStatement covers them): + # 1. Ensure all found placeholders have corresponding params (covered by check inside loop) + # 2. Ensure all provided params correspond to a placeholder + provided_keys = set(parameters.keys()) + found_keys = set(found_params_regex) + unused_keys = provided_keys - found_keys + if unused_keys: + # SQLStatement might handle this, but log a warning just in case. + logger.warning( + "Parameters provided but not used in SQL: %s. SQL: %s", + unused_keys, + sql, + ) + + return final_sql, tuple(ordered_params) # asyncpg expects a sequence + + # Case 2: Parameters are effectively a sequence/scalar (merged by SQLStatement) + if isinstance(parameters, (list, tuple)): + # Parameters are a sequence, need to convert ? -> $n + sequence_processed_parts: list[str] = [] + param_index = 1 + last_end = 0 + qmark_found = False + + # Manually parse the PROCESSED SQL to find '?' outside comments/quotes and convert to $n + for match in QMARK_REGEX.finditer(sql): + if match.group("dquote") or match.group("squote") or match.group("comment"): + continue # Skip quotes and comments + + if match.group("qmark"): + qmark_found = True + start = match.start("qmark") + end = match.end("qmark") + sequence_processed_parts.extend((sql[last_end:start], f"${param_index}")) + last_end = end + param_index += 1 + + sequence_processed_parts.append(sql[last_end:]) + final_sql = "".join(sequence_processed_parts) + + # --- Validation --- + # Check if '?' was found if parameters were provided + if parameters and not qmark_found: + # SQLStatement might allow this, log a warning. + logger.warning( + "Sequence/scalar parameters provided, but no '?' placeholders found. SQL: %s", + sql, + ) + # Return PROCESSED SQL from SQLStatement as no conversion happened here + return sql, parameters + + # Check parameter count match (using count from manual parsing vs count from stmt) + expected_params = param_index - 1 + actual_params = len(parameters) + if expected_params != actual_params: + msg = ( + f"Parameter count mismatch: Processed SQL expected {expected_params} parameters ('$n'), " + f"but {actual_params} were provided by SQLStatement. " + f"Final Processed SQL: {final_sql}" + ) + raise SQLParsingError(msg) + + return final_sql, parameters + + # Case 3: Parameters are None (as determined by SQLStatement) + # processed_params is None + # Check if the SQL contains any placeholders unexpectedly + # Check for :name style + named_placeholders_found = False + for match in PARAM_REGEX.finditer(sql): + if not (match.group("dquote") or match.group("squote") or match.group("comment")) and match.group( + "var_name" + ): + named_placeholders_found = True + break + if named_placeholders_found: + msg = f"Processed SQL contains named parameters (:name) but no parameters were provided. SQL: {sql}" + raise SQLParsingError(msg) + + # Check for ? style + qmark_placeholders_found = False + for match in QMARK_REGEX.finditer(sql): + if not (match.group("dquote") or match.group("squote") or match.group("comment")) and match.group("qmark"): + qmark_placeholders_found = True + break + if qmark_placeholders_found: + msg = f"Processed SQL contains positional parameters (?) but no parameters were provided. SQL: {sql}" + raise SQLParsingError(msg) + + # No parameters provided and none found in SQL, return original SQL from SQLStatement and empty tuple + return sql, () # asyncpg expects a sequence, even if empty + async def select( self, sql: str, @@ -49,9 +220,9 @@ async def select( """ connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters, **kwargs) - parameters = parameters if parameters is not None else () + parameters = parameters if parameters is not None else {} - results = await connection.fetch(sql, *parameters) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] + results = await connection.fetch(sql, *parameters) # pyright: ignore if not results: return [] if schema_type is None: @@ -82,8 +253,8 @@ async def select_one( """ connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters, **kwargs) - parameters = parameters if parameters is not None else () - result = await connection.fetchrow(sql, *parameters) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] + parameters = parameters if parameters is not None else {} + result = await connection.fetchrow(sql, *parameters) # pyright: ignore result = self.check_not_found(result) if schema_type is None: @@ -115,8 +286,8 @@ async def select_one_or_none( """ connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters, **kwargs) - parameters = parameters if parameters is not None else () - result = await connection.fetchrow(sql, *parameters) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] + parameters = parameters if parameters is not None else {} + result = await connection.fetchrow(sql, *parameters) # pyright: ignore if result is None: return None if schema_type is None: @@ -148,12 +319,12 @@ async def select_value( """ connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters, **kwargs) - parameters = parameters if parameters is not None else () - result = await connection.fetchval(sql, *parameters) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] + parameters = parameters if parameters is not None else {} + result = await connection.fetchval(sql, *parameters) # pyright: ignore result = self.check_not_found(result) if schema_type is None: - return result[0] - return schema_type(result[0]) # type: ignore[call-arg] + return result + return schema_type(result) # type: ignore[call-arg] async def select_value_or_none( self, @@ -172,13 +343,13 @@ async def select_value_or_none( """ connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters, **kwargs) - parameters = parameters if parameters is not None else () - result = await connection.fetchval(sql, *parameters) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] + parameters = parameters if parameters is not None else {} + result = await connection.fetchval(sql, *parameters) # pyright: ignore if result is None: return None if schema_type is None: - return result[0] - return schema_type(result[0]) # type: ignore[call-arg] + return result + return schema_type(result) # type: ignore[call-arg] async def insert_update_delete( self, @@ -202,8 +373,8 @@ async def insert_update_delete( """ connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters, **kwargs) - parameters = parameters if parameters is not None else () - status = await connection.execute(sql, *parameters) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] + parameters = parameters if parameters is not None else {} + status = await connection.execute(sql, *parameters) # pyright: ignore # AsyncPG returns a string like "INSERT 0 1" where the last number is the affected rows try: return int(status.split()[-1]) # pyright: ignore[reportUnknownMemberType] @@ -234,8 +405,8 @@ async def insert_update_delete_returning( """ connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters, **kwargs) - parameters = parameters if parameters is not None else () - result = await connection.fetchrow(sql, *parameters) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] + parameters = parameters if parameters is not None else {} + result = await connection.fetchrow(sql, *parameters) # pyright: ignore if result is None: return None if schema_type is None: @@ -265,5 +436,9 @@ async def execute_script( """ connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters, **kwargs) - parameters = parameters if parameters is not None else () - return await connection.execute(sql, *parameters) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] + parameters = parameters if parameters is not None else {} + return await connection.execute(sql, *parameters) # pyright: ignore + + def _connection(self, connection: Optional["AsyncpgConnection"] = None) -> "AsyncpgConnection": + """Return the connection to use. If None, use the default connection.""" + return connection if connection is not None else self.connection diff --git a/sqlspec/adapters/duckdb/driver.py b/sqlspec/adapters/duckdb/driver.py index 80677c574..18a4abec0 100644 --- a/sqlspec/adapters/duckdb/driver.py +++ b/sqlspec/adapters/duckdb/driver.py @@ -1,7 +1,6 @@ from contextlib import contextmanager from typing import TYPE_CHECKING, Any, Optional, Union, cast -from sqlspec._typing import ArrowTable from sqlspec.base import SyncArrowBulkOperationsMixin, SyncDriverAdapterProtocol, T if TYPE_CHECKING: @@ -19,6 +18,7 @@ class DuckDBDriver(SyncArrowBulkOperationsMixin["DuckDBPyConnection"], SyncDrive connection: "DuckDBPyConnection" use_cursor: bool = True + dialect: str = "duckdb" def __init__(self, connection: "DuckDBPyConnection", use_cursor: bool = True) -> None: self.connection = connection @@ -27,7 +27,6 @@ def __init__(self, connection: "DuckDBPyConnection", use_cursor: bool = True) -> # --- Helper Methods --- # def _cursor(self, connection: "DuckDBPyConnection") -> "DuckDBPyConnection": if self.use_cursor: - # Ignore lack of type hint on cursor() return connection.cursor() return connection @@ -40,9 +39,9 @@ def _with_cursor(self, connection: "DuckDBPyConnection") -> "Generator[DuckDBPyC finally: cursor.close() else: - yield connection # Yield the connection directly + yield connection - # --- Public API Methods (Original Implementation + _process_sql_params) --- # + # --- Public API Methods --- # def select( self, @@ -54,18 +53,6 @@ def select( schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "list[Union[ModelDTOT, dict[str, Any]]]": - """Fetch data from the database. - - Args: - sql: SQL statement. - parameters: Query parameters. - connection: Optional connection to use. - schema_type: Optional schema class for the result. - **kwargs: Additional keyword arguments. - - Returns: - List of row data as either model instances or dictionaries. - """ connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: @@ -91,8 +78,7 @@ def select_one( **kwargs: Any, ) -> "Union[ModelDTOT, dict[str, Any]]": connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] result = cursor.fetchone() # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] @@ -116,7 +102,6 @@ def select_one_or_none( ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]": connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters, **kwargs) - with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] result = cursor.fetchone() # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] @@ -139,8 +124,7 @@ def select_value( **kwargs: Any, ) -> "Union[T, Any]": connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) - + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] result = cursor.fetchone() # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] @@ -160,7 +144,7 @@ def select_value_or_none( **kwargs: Any, ) -> "Optional[Union[T, Any]]": connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters) + sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] result = cursor.fetchone() # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] @@ -233,13 +217,7 @@ def select_arrow( # pyright: ignore[reportUnknownParameterType] *, connection: "Optional[DuckDBPyConnection]" = None, **kwargs: Any, - ) -> "ArrowTable": # pyright: ignore[reportUnknownVariableType] - """Execute a SQL query and return results as an Apache Arrow Table. - - Returns: - An Apache Arrow Table containing the query results. - """ - + ) -> "ArrowTable": connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: diff --git a/sqlspec/adapters/oracledb/driver.py b/sqlspec/adapters/oracledb/driver.py index bd3751f93..40e9cabbd 100644 --- a/sqlspec/adapters/oracledb/driver.py +++ b/sqlspec/adapters/oracledb/driver.py @@ -25,6 +25,7 @@ class OracleSyncDriver(SyncArrowBulkOperationsMixin["Connection"], SyncDriverAda """Oracle Sync Driver Adapter.""" connection: "Connection" + dialect: str = "oracle" def __init__(self, connection: "Connection") -> None: self.connection = connection @@ -304,6 +305,7 @@ class OracleAsyncDriver( """Oracle Async Driver Adapter.""" connection: "AsyncConnection" + dialect: str = "oracle" def __init__(self, connection: "AsyncConnection") -> None: self.connection = connection diff --git a/sqlspec/adapters/psycopg/driver.py b/sqlspec/adapters/psycopg/driver.py index 09375d970..345593c02 100644 --- a/sqlspec/adapters/psycopg/driver.py +++ b/sqlspec/adapters/psycopg/driver.py @@ -1,9 +1,12 @@ +import logging from contextlib import asynccontextmanager, contextmanager from typing import TYPE_CHECKING, Any, Optional, Union, cast from psycopg.rows import dict_row from sqlspec.base import AsyncDriverAdapterProtocol, SyncDriverAdapterProtocol, T +from sqlspec.exceptions import SQLParsingError +from sqlspec.statement import PARAM_REGEX, SQLStatement if TYPE_CHECKING: from collections.abc import AsyncGenerator, Generator @@ -12,6 +15,8 @@ from sqlspec.typing import ModelDTOT, StatementParameterType +logger = logging.getLogger("sqlspec") + __all__ = ("PsycopgAsyncDriver", "PsycopgSyncDriver") @@ -19,11 +24,63 @@ class PsycopgSyncDriver(SyncDriverAdapterProtocol["Connection"]): """Psycopg Sync Driver Adapter.""" connection: "Connection" - param_style: str = "%s" + dialect: str = "postgres" def __init__(self, connection: "Connection") -> None: self.connection = connection + def _process_sql_params( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + **kwargs: Any, + ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]": + """Process SQL and parameters, converting :name -> %(name)s if needed.""" + stmt = SQLStatement(sql=sql, parameters=parameters, dialect=self.dialect, kwargs=kwargs or None) + processed_sql, processed_params = stmt.process() + + if isinstance(processed_params, dict): + parameter_dict = processed_params + processed_sql_parts: list[str] = [] + last_end = 0 + found_params_regex: list[str] = [] + + for match in PARAM_REGEX.finditer(processed_sql): + if match.group("dquote") or match.group("squote") or match.group("comment"): + continue + + if match.group("var_name"): + var_name = match.group("var_name") + found_params_regex.append(var_name) + start = match.start("var_name") - 1 + end = match.end("var_name") + + if var_name not in parameter_dict: + msg = ( + f"Named parameter ':{var_name}' found in SQL but missing from processed parameters. " + f"Processed SQL: {processed_sql}" + ) + raise SQLParsingError(msg) + + processed_sql_parts.extend((processed_sql[last_end:start], f"%({var_name})s")) + last_end = end + + processed_sql_parts.append(processed_sql[last_end:]) + final_sql = "".join(processed_sql_parts) + + if not found_params_regex and parameter_dict: + logger.warning( + "Dict params provided (%s), but no :name placeholders found. SQL: %s", + list(parameter_dict.keys()), + processed_sql, + ) + return processed_sql, parameter_dict + + return final_sql, parameter_dict + + return processed_sql, processed_params + @staticmethod @contextmanager def _with_cursor(connection: "Connection") -> "Generator[Any, None, None]": @@ -132,7 +189,8 @@ def select_value( cursor.execute(sql, parameters) row = cursor.fetchone() row = self.check_not_found(row) - val = next(iter(row)) + val = next(iter(row.values())) if row else None + val = self.check_not_found(val) if schema_type is not None: return schema_type(val) # type: ignore[call-arg] return val @@ -147,10 +205,10 @@ def select_value_or_none( schema_type: "Optional[type[T]]" = None, **kwargs: Any, ) -> "Optional[Union[T, Any]]": - """Fetch one row from the database. + """Fetch a single value from the database. Returns: - The first row of the query results. + The first value from the first row of results, or None if no results. """ connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters, **kwargs) @@ -159,7 +217,9 @@ def select_value_or_none( row = cursor.fetchone() if row is None: return None - val = next(iter(row)) + val = next(iter(row.values())) if row else None + if val is None: + return None if schema_type is not None: return schema_type(val) # type: ignore[call-arg] return val @@ -230,18 +290,70 @@ def execute_script( sql, parameters = self._process_sql_params(sql, parameters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) - return str(cursor.rowcount) + return str(cursor.statusmessage) if cursor.statusmessage is not None else "DONE" class PsycopgAsyncDriver(AsyncDriverAdapterProtocol["AsyncConnection"]): """Psycopg Async Driver Adapter.""" connection: "AsyncConnection" - param_style: str = "%s" + dialect: str = "postgres" def __init__(self, connection: "AsyncConnection") -> None: self.connection = connection + def _process_sql_params( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + **kwargs: Any, + ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]": + """Process SQL and parameters, converting :name -> %(name)s if needed.""" + stmt = SQLStatement(sql=sql, parameters=parameters, dialect=self.dialect, kwargs=kwargs or None) + processed_sql, processed_params = stmt.process() + + if isinstance(processed_params, dict): + parameter_dict = processed_params + processed_sql_parts: list[str] = [] + last_end = 0 + found_params_regex: list[str] = [] + + for match in PARAM_REGEX.finditer(processed_sql): + if match.group("dquote") or match.group("squote") or match.group("comment"): + continue + + if match.group("var_name"): + var_name = match.group("var_name") + found_params_regex.append(var_name) + start = match.start("var_name") - 1 + end = match.end("var_name") + + if var_name not in parameter_dict: + msg = ( + f"Named parameter ':{var_name}' found in SQL but missing from processed parameters. " + f"Processed SQL: {processed_sql}" + ) + raise SQLParsingError(msg) + + processed_sql_parts.extend((processed_sql[last_end:start], f"%({var_name})s")) + last_end = end + + processed_sql_parts.append(processed_sql[last_end:]) + final_sql = "".join(processed_sql_parts) + + if not found_params_regex and parameter_dict: + logger.warning( + "Dict params provided (%s), but no :name placeholders found. SQL: %s", + list(parameter_dict.keys()), + processed_sql, + ) + return processed_sql, parameter_dict + + return final_sql, parameter_dict + + return processed_sql, processed_params + @staticmethod @asynccontextmanager async def _with_cursor(connection: "AsyncConnection") -> "AsyncGenerator[Any, None]": @@ -354,7 +466,8 @@ async def select_value( await cursor.execute(sql, parameters) row = await cursor.fetchone() row = self.check_not_found(row) - val = next(iter(row)) + val = next(iter(row.values())) if row else None + val = self.check_not_found(val) if schema_type is not None: return schema_type(val) # type: ignore[call-arg] return val @@ -382,7 +495,9 @@ async def select_value_or_none( row = await cursor.fetchone() if row is None: return None - val = next(iter(row)) + val = next(iter(row.values())) if row else None + if val is None: + return None if schema_type is not None: return schema_type(val) # type: ignore[call-arg] return val @@ -460,4 +575,4 @@ async def execute_script( async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) - return str(cursor.rowcount) + return str(cursor.statusmessage) if cursor.statusmessage is not None else "DONE" diff --git a/sqlspec/adapters/sqlite/driver.py b/sqlspec/adapters/sqlite/driver.py index 1e8b20bde..7c5916abb 100644 --- a/sqlspec/adapters/sqlite/driver.py +++ b/sqlspec/adapters/sqlite/driver.py @@ -16,6 +16,7 @@ class SqliteDriver(SyncDriverAdapterProtocol["Connection"]): """SQLite Sync Driver Adapter.""" connection: "Connection" + dialect: str = "sqlite" def __init__(self, connection: "Connection") -> None: self.connection = connection @@ -228,10 +229,23 @@ def insert_update_delete_returning( result = cursor.fetchall() if len(result) == 0: return None + + # Get column names from cursor description column_names = [c[0] for c in cursor.description or []] + + # Get the first row's values - ensure we're getting the actual values + row_values = result[0] + + # Debug print to see what we're getting + + # Create dictionary mapping column names to values + result_dict = {} + for i, col_name in enumerate(column_names): + result_dict[col_name] = row_values[i] + if schema_type is not None: - return cast("ModelDTOT", schema_type(**dict(zip(column_names, result[0])))) - return dict(zip(column_names, result[0])) + return cast("ModelDTOT", schema_type(**result_dict)) + return result_dict def execute_script( self, @@ -250,15 +264,14 @@ def execute_script( connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters, **kwargs) - # For DDL statements, don't pass parameters to execute - # SQLite doesn't support parameters for DDL statements + # The _process_sql_params handles parameter formatting for the dialect. with self._with_cursor(connection) as cursor: if not parameters: cursor.execute(sql) # pyright: ignore[reportUnknownMemberType] else: cursor.execute(sql, parameters) - return cast("str", cursor.statusmessage) if hasattr(cursor, "statusmessage") else "DONE" # pyright: ignore[reportUnknownMemberType,reportAttributeAccessIssue] + return cast("str", cursor.statusmessage) if hasattr(cursor, "statusmessage") else "DONE" # pyright: ignore[reportUnknownMemberType,reportAttributeAccessIssue] def execute_script_returning( self, diff --git a/sqlspec/base.py b/sqlspec/base.py index 798fa27ec..855432b17 100644 --- a/sqlspec/base.py +++ b/sqlspec/base.py @@ -1,4 +1,4 @@ -# ruff: noqa: PLR6301, PLR0912, PLR0915, C901, PLR0911 +# ruff: noqa: PLR6301 import re from abc import ABC, abstractmethod from collections.abc import Awaitable @@ -16,10 +16,8 @@ overload, ) -import sqlglot -from sqlglot import exp - -from sqlspec.exceptions import NotFoundError, SQLParsingError +from sqlspec.exceptions import NotFoundError +from sqlspec.statement import SQLStatement from sqlspec.typing import ModelDTOT, StatementParameterType if TYPE_CHECKING: @@ -37,6 +35,7 @@ "NoPoolAsyncConfig", "NoPoolSyncConfig", "SQLSpec", + "SQLStatement", "SyncArrowBulkOperationsMixin", "SyncDatabaseConfig", "SyncDriverAdapterProtocol", @@ -53,13 +52,15 @@ bound="Union[Union[AsyncDatabaseConfig[Any, Any, Any], NoPoolAsyncConfig[Any, Any]], SyncDatabaseConfig[Any, Any, Any], NoPoolSyncConfig[Any, Any]]", ) DriverT = TypeVar("DriverT", bound="Union[SyncDriverAdapterProtocol[Any], AsyncDriverAdapterProtocol[Any]]") - -# Regex to find :param style placeholders, avoiding those inside quotes -# Handles basic cases, might need refinement for complex SQL +# Regex to find :param or %(param)s style placeholders, skipping those inside quotes PARAM_REGEX = re.compile( - r"(?P\"(?:[^\"]|\"\")*\")|" # Double-quoted strings - r"(?P'(?:[^']|'')*')|" # Single-quoted strings - r"(?P[^:]):(?P[a-zA-Z_][a-zA-Z0-9_]*)" # :param placeholder + r""" + (?P"([^"]|\\")*") | # Double-quoted strings + (?P'([^']|\\')*') | # Single-quoted strings + : (?P[a-zA-Z_][a-zA-Z0-9_]*) | # :var_name + % \( (?P[a-zA-Z_][a-zA-Z0-9_]*) \) s # %(var_name)s + """, + re.VERBOSE, ) @@ -418,8 +419,8 @@ def close_pool( class CommonDriverAttributes(Generic[ConnectionT]): """Common attributes and methods for driver adapters.""" - param_style: str = "?" - """The parameter style placeholder supported by the underlying database driver (e.g., '?', '%s').""" + dialect: str + """The SQL dialect supported by the underlying database driver (e.g., 'postgres', 'mysql').""" connection: ConnectionT """The connection to the underlying database.""" __supports_arrow__: ClassVar[bool] = False @@ -446,182 +447,23 @@ def check_not_found(item_or_none: Optional[T] = None) -> T: raise NotFoundError(msg) return item_or_none - def _process_sql_statement(self, sql: str) -> str: - """Perform any preprocessing of the SQL query string if needed. - Default implementation returns the SQL unchanged. - - Args: - sql: The SQL query string. - - Returns: - The processed SQL query string. - """ - return sql - def _process_sql_params( self, sql: str, parameters: "Optional[StatementParameterType]" = None, /, **kwargs: Any ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]": - """Process SQL query and parameters for DB-API execution. - - Uses sqlglot to parse named parameters (:name) if parameters is a dictionary, - and converts them to the driver's `param_style`. - Handles single value parameters by wrapping them in a tuple. + """Process SQL query and parameters using SQLStatement for validation and formatting. Args: sql: The SQL query string. - parameters: The parameters for the query (dict, tuple, list, single value, or None). + parameters: Parameters for the query. **kwargs: Additional keyword arguments to merge with parameters if parameters is a dict. Returns: - A tuple containing the processed SQL string and the processed parameters - (tuple for named/single params, original list/tuple for positional, None if no params). - - Raises: - ValueError: If parameter validation fails (missing/extra keys for dicts, - mixing named/positional placeholders with dicts). - ImportError: If sqlglot is not installed. + A tuple containing the processed SQL query and parameters. """ - # 1. Handle None and kwargs - if parameters is None and not kwargs: - return self._process_sql_statement(sql), None - - # 2. Merge parameters with kwargs if parameters is a dict - parameters = {**parameters, **kwargs} if isinstance(parameters, dict) else kwargs if kwargs else parameters - - # 3. Handle dictionary parameters using sqlglot - if isinstance(parameters, dict): - if not parameters: - # Return early for empty dict - return self._process_sql_statement(sql), parameters - - # First check if there are any :param style placeholders using regex - regex_placeholders = [] - for match in PARAM_REGEX.finditer(sql): - if match.group("dquote") is not None or match.group("squote") is not None: - continue - var_name = match.group("var_name") - if var_name is not None: - regex_placeholders.append(var_name) - - try: - expression = sqlglot.parse_one(sql) - except Exception as e: - # If sqlglot parsing fails but regex found placeholders, use regex approach - if regex_placeholders: - # Use regex approach as fallback - processed_sql = sql - param_values = [] - for key, value in parameters.items(): - if key in regex_placeholders: - processed_sql = processed_sql.replace(f":{key}", self.param_style) - param_values.append(value) - - # Validate that all placeholders were found - if len(param_values) != len(regex_placeholders): - msg = f"Not all placeholders found in parameters: {set(regex_placeholders) - set(parameters.keys())}" - raise SQLParsingError(msg) from e - - return self._process_sql_statement(processed_sql), tuple(param_values) - - msg = f"sqlglot failed to parse SQL: {e}" - raise SQLParsingError(msg) from e - - placeholders = list(expression.find_all(exp.Parameter)) - placeholder_names: list[str] = [] - has_unnamed = False - for p in placeholders: - if p.name: - placeholder_names.append(p.name) - else: - has_unnamed = True # Found unnamed placeholder like '?' - - # If sqlglot didn't find any placeholders but regex did, use regex approach - if not placeholder_names and regex_placeholders: - processed_sql = sql - param_values = [] - for key, value in parameters.items(): - if key in regex_placeholders: - processed_sql = processed_sql.replace(f":{key}", self.param_style) - param_values.append(value) - - # Validate that all placeholders were found - if len(param_values) != len(regex_placeholders): - msg = ( - f"Not all placeholders found in parameters: {set(regex_placeholders) - set(parameters.keys())}" - ) - raise SQLParsingError(msg) - - return self._process_sql_statement(processed_sql), tuple(param_values) - - if has_unnamed: - msg = "Cannot use dictionary parameters with unnamed placeholders (e.g., '?') in the SQL query." - raise SQLParsingError(msg) - - if not placeholder_names: - # If no named placeholders found, but dict was provided, raise error. - # (We already handled the empty dict case above) - msg = "Dictionary parameters provided, but no named placeholders found in the SQL query." - raise SQLParsingError(msg) - - # Validation - provided_keys = set(parameters.keys()) - required_keys = set(placeholder_names) - - missing_keys = required_keys - provided_keys - if missing_keys: - msg = f"Named parameters found in SQL but not provided in parameters dictionary: {missing_keys}" - raise SQLParsingError(msg) - - extra_keys = provided_keys - required_keys - if extra_keys: - msg = f"Parameters provided but not found in SQL: {extra_keys}" - raise SQLParsingError(msg) # Strict check - - # Build ordered tuple of parameters - ordered_params = tuple(parameters[name] for name in placeholder_names) - - # Replace :name with self.param_style using regex for safety - processed_sql = "" - last_end = 0 - params_iter = iter(placeholder_names) # Ensure order correctness during replacement - - for match in PARAM_REGEX.finditer(sql): - if match.group("dquote") is not None or match.group("squote") is not None: - processed_sql += sql[last_end : match.end()] - last_end = match.end() - continue - - var_name = match.group("var_name") - if var_name is None: - processed_sql += sql[last_end : match.end()] - last_end = match.end() - continue - - expected_param = next(params_iter, None) - if var_name != expected_param: - msg = f"Internal parameter processing mismatch: Regex found ':{var_name}' but expected ':{expected_param}' based on sqlglot parse order." - raise SQLParsingError(msg) - - # Replace :param with param_style - start_replace = match.start("var_name") - 1 # Include the ':' - processed_sql += sql[last_end:start_replace] + self.param_style - last_end = match.end("var_name") - - processed_sql += sql[last_end:] # Append remaining part - - final_sql = self._process_sql_statement(processed_sql) - return final_sql, ordered_params - - # 4. Handle list/tuple parameters (positional) - if isinstance(parameters, (list, tuple)): - # Let the underlying driver handle these directly - return self._process_sql_statement(sql), parameters - - # 5. Handle single value parameters - # If it wasn't None, dict, list, or tuple, it must be a single value - processed_params: tuple[Any, ...] = (parameters,) - # Assuming single value maps to a single positional placeholder. - return self._process_sql_statement(sql), processed_params + # Instantiate SQLStatement with parameters and kwargs for internal merging + stmt = SQLStatement(sql=sql, parameters=parameters, dialect=self.dialect, kwargs=kwargs or None) + # Process uses the merged parameters internally + return stmt.process() class SyncArrowBulkOperationsMixin(Generic[ConnectionT]): @@ -629,9 +471,6 @@ class SyncArrowBulkOperationsMixin(Generic[ConnectionT]): __supports_arrow__: "ClassVar[bool]" = True - def __init__(self, connection: ConnectionT) -> None: - self.connection = connection - @abstractmethod def select_arrow( # pyright: ignore[reportUnknownParameterType] self, @@ -659,7 +498,7 @@ def select_arrow( # pyright: ignore[reportUnknownParameterType] class SyncDriverAdapterProtocol(CommonDriverAttributes[ConnectionT], ABC, Generic[ConnectionT]): connection: "ConnectionT" - def __init__(self, connection: "ConnectionT") -> None: + def __init__(self, connection: "ConnectionT", **kwargs: Any) -> None: self.connection = connection @abstractmethod diff --git a/sqlspec/exceptions.py b/sqlspec/exceptions.py index e78eea62e..73eac9e32 100644 --- a/sqlspec/exceptions.py +++ b/sqlspec/exceptions.py @@ -1,3 +1,5 @@ +from collections.abc import Generator +from contextlib import contextmanager from typing import Any, Optional __all__ = ( @@ -6,7 +8,9 @@ "MissingDependencyError", "MultipleResultsFoundError", "NotFoundError", + "ParameterStyleMismatchError", "RepositoryError", + "SQLParsingError", "SQLSpecError", "SerializationError", ) @@ -74,6 +78,20 @@ def __init__(self, message: Optional[str] = None) -> None: super().__init__(message) +class ParameterStyleMismatchError(SQLSpecError): + """Error when parameter style doesn't match SQL placeholder style. + + This exception is raised when there's a mismatch between the parameter type + (dictionary, tuple, etc.) and the placeholder style in the SQL query + (named, positional, etc.). + """ + + def __init__(self, message: Optional[str] = None) -> None: + if message is None: + message = "Parameter style mismatch: dictionary parameters provided but no named placeholders found in SQL." + super().__init__(message) + + class ImproperConfigurationError(SQLSpecError): """Improper Configuration error. @@ -99,3 +117,15 @@ class NotFoundError(RepositoryError): class MultipleResultsFoundError(RepositoryError): """A single database result was required but more than one were found.""" + + +@contextmanager +def wrap_exceptions(wrap_exceptions: bool = True) -> Generator[None, None, None]: + try: + yield + + except Exception as exc: + if wrap_exceptions is False: + raise + msg = "An error occurred during the operation." + raise RepositoryError(detail=msg) from exc diff --git a/sqlspec/statement.py b/sqlspec/statement.py new file mode 100644 index 000000000..a02896d50 --- /dev/null +++ b/sqlspec/statement.py @@ -0,0 +1,373 @@ +# ruff: noqa: RUF100, PLR6301, PLR0912, PLR0915, C901, PLR0911, PLR0914, N806 +import logging +import re +from dataclasses import dataclass +from functools import cached_property +from typing import ( + Any, + Optional, + Union, +) + +import sqlglot +from sqlglot import exp + +from sqlspec.exceptions import ParameterStyleMismatchError, SQLParsingError +from sqlspec.typing import StatementParameterType + +__all__ = ("SQLStatement",) + +logger = logging.getLogger("sqlspec") + +# Regex to find :param style placeholders, skipping those inside quotes or SQL comments +# Adapted from previous version in psycopg adapter +PARAM_REGEX = re.compile( + r"""(?"(?:[^"]|"")*") | # Double-quoted strings (support SQL standard escaping "") + (?P'(?:[^']|'')*') | # Single-quoted strings (support SQL standard escaping '') + (?P--.*?\n|\/\*.*?\*\/) | # SQL comments (single line or multi-line) + : (?P[a-zA-Z_][a-zA-Z0-9_]*) # :var_name identifier + ) + """, + re.VERBOSE | re.DOTALL, +) + + +@dataclass() +class SQLStatement: + """An immutable representation of a SQL statement with its parameters. + + This class encapsulates the SQL statement and its parameters, providing + a clean interface for parameter binding and SQL statement formatting. + """ + + dialect: str + """The SQL dialect to use for parsing (e.g., 'postgres', 'mysql'). Defaults to 'postgres' if None.""" + sql: str + """The raw SQL statement.""" + parameters: Optional[StatementParameterType] = None + """The parameters for the SQL statement.""" + kwargs: Optional[dict[str, Any]] = None + """Keyword arguments passed for parameter binding.""" + + _merged_parameters: Optional[Union[StatementParameterType, dict[str, Any]]] = None + + def __post_init__(self) -> None: + """Merge parameters and kwargs after initialization.""" + merged_params = self.parameters + + if self.kwargs: + if merged_params is None: + merged_params = self.kwargs + elif isinstance(merged_params, dict): + # Merge kwargs into parameters dict, kwargs take precedence + merged_params = {**merged_params, **self.kwargs} + else: + # If parameters is sequence or scalar, kwargs replace it + # Consider adding a warning here if this behavior is surprising + merged_params = self.kwargs + + self._merged_parameters = merged_params + + def process(self) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]": + """Process the SQL statement and merged parameters for execution. + + Returns: + A tuple containing the processed SQL string and the processed parameters + ready for database driver execution. + + Raises: + SQLParsingError: If the SQL statement contains parameter placeholders, but no parameters were provided. + + Returns: + A tuple containing the processed SQL string and the processed parameters + ready for database driver execution. + """ + if self._merged_parameters is None: + # Validate that the SQL doesn't expect parameters if none were provided + # Parse ONLY if we need to validate + try: # Add try/except in case parsing fails even here + expression = self._parse_sql() + except SQLParsingError: + # If parsing fails, we can't validate, but maybe that's okay if no params were passed? + # Log a warning? For now, let the original error propagate if needed. + # Or, maybe assume it's okay if _merged_parameters is None? + # Let's re-raise for now, as unparsable SQL is usually bad. + logger.warning("SQL statement is unparsable: %s", self.sql) + return self.sql, None + if list(expression.find_all(exp.Parameter)): + msg = "SQL statement contains parameter placeholders, but no parameters were provided." + raise SQLParsingError(msg) + return self.sql, None + + if isinstance(self._merged_parameters, dict): + # Pass only the dict, parsing happens inside + return self._process_dict_params(self._merged_parameters) + + if isinstance(self._merged_parameters, (tuple, list)): + # Pass only the sequence, parsing happens inside if needed for validation + return self._process_sequence_params(self._merged_parameters) + + # Assume it's a single scalar value otherwise + # Pass only the value, parsing happens inside for validation + return self._process_scalar_param(self._merged_parameters) + + def _parse_sql(self) -> exp.Expression: + """Parse the SQL using sqlglot. + + Raises: + SQLParsingError: If the SQL statement cannot be parsed. + + Returns: + The parsed SQL expression. + """ + parse_dialect = self.dialect or "postgres" + try: + read_dialect = parse_dialect or None + return sqlglot.parse_one(self.sql, read=read_dialect) + except Exception as e: + # Ensure the original sqlglot error message is included + error_detail = str(e) + msg = f"Failed to parse SQL with dialect '{parse_dialect or 'auto-detected'}': {error_detail}\nSQL: {self.sql}" + raise SQLParsingError(msg) from e + + def _process_dict_params( + self, + parameter_dict: dict[str, Any], + ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]": + """Processes dictionary parameters based on dialect capabilities. + + Raises: + ParameterStyleMismatchError: If the SQL statement contains unnamed placeholders (e.g., '?') in the SQL query. + SQLParsingError: If the SQL statement contains named parameters, but no parameters were provided. + + Returns: + A tuple containing the processed SQL string and the processed parameters + ready for database driver execution. + """ + # Attempt to parse with sqlglot first (for other dialects like postgres, mysql) + named_sql_params: Optional[list[exp.Parameter]] = None + unnamed_sql_params: Optional[list[exp.Parameter]] = None + sqlglot_parsed_ok = False + # --- Dialect-Specific Bypasses for Native Handling --- + if self.dialect == "sqlite": # Handles :name natively + return self.sql, parameter_dict + + # Add bypass for postgres handled by specific adapters (e.g., asyncpg) + if self.dialect == "postgres": + # The adapter (e.g., asyncpg) will handle :name -> $n conversion. + # SQLStatement just validates parameters against the original SQL here. + # Perform validation using regex if sqlglot parsing fails, otherwise use sqlglot. + try: + expression = self._parse_sql() + sql_params = list(expression.find_all(exp.Parameter)) + named_sql_params = [p for p in sql_params if p.name] + unnamed_sql_params = [p for p in sql_params if not p.name] + + if unnamed_sql_params: + msg = "Cannot use dictionary parameters with unnamed placeholders (e.g., '?') found by sqlglot for postgres." + raise ParameterStyleMismatchError(msg) + + # Validate keys using sqlglot results + required_keys = {p.name for p in named_sql_params} + provided_keys = set(parameter_dict.keys()) + missing_keys = required_keys - provided_keys + if missing_keys: + msg = ( + f"Named parameters found in SQL (via sqlglot) but not provided: {missing_keys}. SQL: {self.sql}" + ) + raise SQLParsingError(msg) # noqa: TRY301 + # Allow extra keys + + except SQLParsingError as e: + logger.debug("SQLglot parsing failed for postgres dict params, attempting regex validation: %s", e) + # Regex validation fallback (without conversion) + postgres_found_params_regex: list[str] = [] + for match in PARAM_REGEX.finditer(self.sql): + if match.group("dquote") or match.group("squote") or match.group("comment"): + continue + if match.group("var_name"): + var_name = match.group("var_name") + postgres_found_params_regex.append(var_name) + if var_name not in parameter_dict: + msg = f"Named parameter ':{var_name}' found in SQL (via regex) but not provided. SQL: {self.sql}" + raise SQLParsingError(msg) # noqa: B904 + + if not postgres_found_params_regex and parameter_dict: + msg = f"Dictionary parameters provided, but no named placeholders (:name) found via regex. SQL: {self.sql}" + raise ParameterStyleMismatchError(msg) # noqa: B904 + # Allow extra keys with regex check too + + # Return the *original* SQL and the processed dict for the adapter to handle + return self.sql, parameter_dict + + if self.dialect == "duckdb": # Handles $name natively (and :name via driver? Check driver docs) + # Bypass sqlglot/regex checks. Trust user SQL ($name or ?) + dict for DuckDB driver. + # We lose :name -> $name conversion *if* sqlglot parsing fails, but avoid errors on valid $name SQL. + return self.sql, parameter_dict + # --- End Bypasses --- + + try: + expression = self._parse_sql() + sql_params = list(expression.find_all(exp.Parameter)) + named_sql_params = [p for p in sql_params if p.name] + unnamed_sql_params = [p for p in sql_params if not p.name] + sqlglot_parsed_ok = True + logger.debug("SQLglot parsed dict params successfully for: %s", self.sql) + except SQLParsingError as e: + logger.debug("SQLglot parsing failed for dict params, attempting regex fallback: %s", e) + # Proceed using regex fallback below + + # Check for unnamed placeholders if parsing worked + if sqlglot_parsed_ok and unnamed_sql_params: + msg = "Cannot use dictionary parameters with unnamed placeholders (e.g., '?') found by sqlglot." + raise ParameterStyleMismatchError(msg) + + # Determine if we need to use regex fallback + # Use fallback if: parsing failed OR (parsing worked BUT found no named params when a dict was provided) + use_regex_fallback = not sqlglot_parsed_ok or (not named_sql_params and parameter_dict) + + if use_regex_fallback: + # Regex fallback logic for :name -> self.param_style conversion + # ... (regex fallback code as implemented previously) ... + logger.debug("Using regex fallback for dict param processing: %s", self.sql) + # --- Regex Fallback Logic --- + regex_processed_sql_parts: list[str] = [] + ordered_params = [] + last_end = 0 + regex_found_params: list[str] = [] + + for match in PARAM_REGEX.finditer(self.sql): + # Skip matches that are comments or quoted strings + if match.group("dquote") or match.group("squote") or match.group("comment"): + continue + + if match.group("var_name"): + var_name = match.group("var_name") + regex_found_params.append(var_name) + # Get start and end from the match object for the :var_name part + # The var_name group itself doesn't include the leading :, so adjust start. + start = match.start("var_name") - 1 + end = match.end("var_name") + + if var_name not in parameter_dict: + msg = ( + f"Named parameter ':{var_name}' found in SQL (via regex) but not provided. SQL: {self.sql}" + ) + raise SQLParsingError(msg) + + regex_processed_sql_parts.extend((self.sql[last_end:start], self.param_style)) # Use target style + ordered_params.append(parameter_dict[var_name]) + last_end = end + + regex_processed_sql_parts.append(self.sql[last_end:]) + + # Validation with regex results + if not regex_found_params and parameter_dict: + msg = f"Dictionary parameters provided, but no named placeholders (e.g., :name) found via regex in the SQL query for dialect '{self.dialect}'. SQL: {self.sql}" + raise ParameterStyleMismatchError(msg) + + provided_keys = set(parameter_dict.keys()) + missing_keys = set(regex_found_params) - provided_keys # Should be caught above, but double check + if missing_keys: + msg = f"Named parameters found in SQL (via regex) but not provided: {missing_keys}. SQL: {self.sql}" + raise SQLParsingError(msg) + + extra_keys = provided_keys - set(regex_found_params) + if extra_keys: + # Allow extra keys + pass + + return "".join(regex_processed_sql_parts), tuple(ordered_params) + + # Sqlglot Logic (if parsing worked and found params) + # ... (sqlglot logic as implemented previously, including :name -> %s conversion) ... + logger.debug("Using sqlglot results for dict param processing: %s", self.sql) + + # Ensure named_sql_params is iterable, default to empty list if None (shouldn't happen ideally) + active_named_params = named_sql_params or [] + + if not active_named_params and not parameter_dict: + # No SQL params found by sqlglot, no provided params dict -> OK + return self.sql, () + + # Validation with sqlglot results + required_keys = {p.name for p in active_named_params} # Use active_named_params + provided_keys = set(parameter_dict.keys()) + + missing_keys = required_keys - provided_keys + if missing_keys: + msg = f"Named parameters found in SQL (via sqlglot) but not provided: {missing_keys}. SQL: {self.sql}" + raise SQLParsingError(msg) + + extra_keys = provided_keys - required_keys + if extra_keys: + pass # Allow extra keys + + # Note: DuckDB handled by bypass above if sqlglot fails. + # This block handles successful sqlglot parse for other dialects. + # We don't need the specific DuckDB $name conversion here anymore, + # as the bypass handles the native $name case. + # The general logic converts :name -> self.param_style for dialects like postgres. + # if self.dialect == "duckdb": ... (Removed specific block here) + + # For other dialects requiring positional conversion (using sqlglot param info): + sqlglot_processed_parts: list[str] = [] + ordered_params = [] + last_end = 0 + for param in active_named_params: # Use active_named_params + start = param.this.this.start + end = param.this.this.end + sqlglot_processed_parts.extend((self.sql[last_end:start], self.param_style)) + ordered_params.append(parameter_dict[param.name]) + last_end = end + sqlglot_processed_parts.append(self.sql[last_end:]) + return "".join(sqlglot_processed_parts), tuple(ordered_params) + + def _process_sequence_params( + self, params: Union[tuple[Any, ...], list[Any]] + ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]": + """Processes a sequence of parameters. + + Returns: + A tuple containing the processed SQL string and the processed parameters + ready for database driver execution. + """ + return self.sql, params + + def _process_scalar_param( + self, param_value: Any + ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]": + """Processes a single scalar parameter value. + + Returns: + A tuple containing the processed SQL string and the processed parameters + ready for database driver execution. + """ + return self.sql, (param_value,) + + @cached_property + def param_style(self) -> str: + """Get the parameter style based on the dialect. + + Returns: + The parameter style placeholder for the dialect. + """ + dialect = self.dialect + + # Map dialects to parameter styles for placeholder replacement + # Note: Used when converting named params (:name) for dialects needing positional. + # Dialects supporting named params natively (SQLite, DuckDB) are handled via bypasses. + dialect_to_param_style = { + "postgres": "%s", + "mysql": "%s", + "oracle": ":1", + "mssql": "?", + "bigquery": "?", + "snowflake": "?", + "cockroach": "%s", + "db2": "?", + } + # Default to '?' for unknown/unhandled dialects or when dialect=None is forced + return dialect_to_param_style.get(dialect, "?") diff --git a/sqlspec/typing.py b/sqlspec/typing.py index bb2a1cf8c..ee0aae7a1 100644 --- a/sqlspec/typing.py +++ b/sqlspec/typing.py @@ -79,7 +79,7 @@ - :class:`DTOData`[:type:`list[ModelT]`] """ -StatementParameterType: TypeAlias = "Union[dict[str, Any], list[Any], tuple[Any, ...], None]" +StatementParameterType: TypeAlias = "Union[Any, dict[str, Any], list[Any], tuple[Any, ...], None]" """Type alias for parameter types. Represents: diff --git a/tests/fixtures/sql_utils.py b/tests/fixtures/sql_utils.py index ca69d5265..b548b2db8 100644 --- a/tests/fixtures/sql_utils.py +++ b/tests/fixtures/sql_utils.py @@ -13,11 +13,13 @@ def format_placeholder(field_name: str, style: str, dialect: Optional[str] = Non The formatted placeholder string. """ if style == "tuple_binds": - if dialect in ["sqlite", "duckdb", "aiosqlite"]: + if dialect in {"sqlite", "duckdb", "aiosqlite"}: return "?" # Default to Postgres/BigQuery style return "%s" - if dialect in ["sqlite", "duckdb", "aiosqlite"]: + if dialect == "duckdb": + return f"${field_name}" + if dialect in {"sqlite", "aiosqlite"}: return f":{field_name}" # For postgres and similar return f"%({field_name})s" diff --git a/tests/integration/test_adapters/test_adbc/test_driver_duckdb.py b/tests/integration/test_adapters/test_adbc/test_driver_duckdb.py index 23912f19d..43bd56ece 100644 --- a/tests/integration/test_adapters/test_adbc/test_driver_duckdb.py +++ b/tests/integration/test_adapters/test_adbc/test_driver_duckdb.py @@ -276,3 +276,169 @@ def test_driver_select_arrow(adbc_session: AdbcConfig) -> None: assert arrow_table.column("id").to_pylist() == [1] driver.execute_script("DROP TABLE IF EXISTS test_table") driver.execute_script("DROP SEQUENCE IF EXISTS test_table_id_seq") + + +@xfail_if_driver_missing +@pytest.mark.xdist_group("duckdb") +def test_driver_named_params_with_scalar(adbc_session: AdbcConfig) -> None: + """Test that scalar parameters work with named parameters in SQL.""" + with adbc_session.provide_session() as driver: + # Create test table + create_sequence_sql = "CREATE SEQUENCE test_table_id_seq START 1;" + driver.execute_script(create_sequence_sql) + sql = """ + CREATE TABLE test_table ( + id INTEGER PRIMARY KEY DEFAULT nextval('test_table_id_seq'), + name VARCHAR(50) + ); + """ + driver.execute_script(sql) + + # Insert test record using positional parameter with scalar value + insert_sql = """ + INSERT INTO test_table (name) + VALUES (?) + """ + driver.insert_update_delete(insert_sql, "test_name") + + # Select and verify + select_sql = "SELECT name FROM test_table WHERE name = ?" + results = driver.select(select_sql, "test_name") + assert len(results) == 1 + assert results[0]["name"] == "test_name" + driver.execute_script("DROP TABLE IF EXISTS test_table") + driver.execute_script("DROP SEQUENCE IF EXISTS test_table_id_seq") + + +@xfail_if_driver_missing +@pytest.mark.xdist_group("duckdb") +def test_driver_named_params_with_tuple(adbc_session: AdbcConfig) -> None: + """Test that tuple parameters work with named parameters in SQL.""" + with adbc_session.provide_session() as driver: + # Create test table + create_sequence_sql = "CREATE SEQUENCE test_table_id_seq START 1;" + driver.execute_script(create_sequence_sql) + sql = """ + CREATE TABLE test_table ( + id INTEGER PRIMARY KEY DEFAULT nextval('test_table_id_seq'), + name VARCHAR(50), + age INTEGER + ); + """ + driver.execute_script(sql) + + # Insert test record using positional parameters with tuple values + insert_sql = """ + INSERT INTO test_table (name, age) + VALUES (?, ?) + """ + driver.insert_update_delete(insert_sql, ("test_name", 30)) + + # Select and verify + select_sql = "SELECT name, age FROM test_table WHERE name = ? AND age = ?" + results = driver.select(select_sql, ("test_name", 30)) + assert len(results) == 1 + assert results[0]["name"] == "test_name" + assert results[0]["age"] == 30 + driver.execute_script("DROP TABLE IF EXISTS test_table") + driver.execute_script("DROP SEQUENCE IF EXISTS test_table_id_seq") + + +@xfail_if_driver_missing +@pytest.mark.xdist_group("duckdb") +def test_driver_native_named_params(adbc_session: AdbcConfig) -> None: + """Test DuckDB's native named parameter style ($name).""" + with adbc_session.provide_session() as driver: + # Create test table + create_sequence_sql = "CREATE SEQUENCE test_table_id_seq START 1;" + driver.execute_script(create_sequence_sql) + sql = """ + CREATE TABLE test_table ( + id INTEGER PRIMARY KEY DEFAULT nextval('test_table_id_seq'), + name VARCHAR(50) + ); + """ + driver.execute_script(sql) + + # Insert test record using native $name style + insert_sql = """ + INSERT INTO test_table (name) + VALUES ($name) + """ + driver.insert_update_delete(insert_sql, {"name": "native_name"}) + + # Select and verify + select_sql = "SELECT name FROM test_table WHERE name = $name" + results = driver.select(select_sql, {"name": "native_name"}) + assert len(results) == 1 + assert results[0]["name"] == "native_name" + driver.execute_script("DROP TABLE IF EXISTS test_table") + driver.execute_script("DROP SEQUENCE IF EXISTS test_table_id_seq") + + +@xfail_if_driver_missing +@pytest.mark.xdist_group("duckdb") +def test_driver_native_positional_params(adbc_session: AdbcConfig) -> None: + """Test DuckDB's native positional parameter style ($1, $2, etc.).""" + with adbc_session.provide_session() as driver: + # Create test table + create_sequence_sql = "CREATE SEQUENCE test_table_id_seq START 1;" + driver.execute_script(create_sequence_sql) + sql = """ + CREATE TABLE test_table ( + id INTEGER PRIMARY KEY DEFAULT nextval('test_table_id_seq'), + name VARCHAR(50), + age INTEGER + ); + """ + driver.execute_script(sql) + + # Insert test record using native $1 style + insert_sql = """ + INSERT INTO test_table (name, age) + VALUES ($1, $2) + """ + driver.insert_update_delete(insert_sql, ("native_pos", 30)) + + # Select and verify + select_sql = "SELECT name, age FROM test_table WHERE name = $1 AND age = $2" + results = driver.select(select_sql, ("native_pos", 30)) + assert len(results) == 1 + assert results[0]["name"] == "native_pos" + assert results[0]["age"] == 30 + driver.execute_script("DROP TABLE IF EXISTS test_table") + driver.execute_script("DROP SEQUENCE IF EXISTS test_table_id_seq") + + +@xfail_if_driver_missing +@pytest.mark.xdist_group("duckdb") +def test_driver_native_auto_incremented_params(adbc_session: AdbcConfig) -> None: + """Test DuckDB's native auto-incremented parameter style (?).""" + with adbc_session.provide_session() as driver: + # Create test table + create_sequence_sql = "CREATE SEQUENCE test_table_id_seq START 1;" + driver.execute_script(create_sequence_sql) + sql = """ + CREATE TABLE test_table ( + id INTEGER PRIMARY KEY DEFAULT nextval('test_table_id_seq'), + name VARCHAR(50), + age INTEGER + ); + """ + driver.execute_script(sql) + + # Insert test record using native ? style + insert_sql = """ + INSERT INTO test_table (name, age) + VALUES (?, ?) + """ + driver.insert_update_delete(insert_sql, ("native_auto", 35)) + + # Select and verify + select_sql = "SELECT name, age FROM test_table WHERE name = ? AND age = ?" + results = driver.select(select_sql, ("native_auto", 35)) + assert len(results) == 1 + assert results[0]["name"] == "native_auto" + assert results[0]["age"] == 35 + driver.execute_script("DROP TABLE IF EXISTS test_table") + driver.execute_script("DROP SEQUENCE IF EXISTS test_table_id_seq") diff --git a/tests/integration/test_adapters/test_adbc/test_driver_sqlite.py b/tests/integration/test_adapters/test_adbc/test_driver_sqlite.py index fa7c5ea7e..8e85381d8 100644 --- a/tests/integration/test_adapters/test_adbc/test_driver_sqlite.py +++ b/tests/integration/test_adapters/test_adbc/test_driver_sqlite.py @@ -250,3 +250,61 @@ def test_driver_select_arrow(adbc_session: AdbcConfig) -> None: # Assuming id is 1 for the inserted record assert arrow_table.column("id").to_pylist() == [1] driver.execute_script("DROP TABLE IF EXISTS test_table") + + +@xfail_if_driver_missing +@pytest.mark.xdist_group("sqlite") +def test_driver_named_params_with_scalar(adbc_session: AdbcConfig) -> None: + """Test that scalar parameters work with named parameters in SQL.""" + with adbc_session.provide_session() as driver: + sql = """ + CREATE TABLE test_table ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name VARCHAR(50) + ); + """ + driver.execute_script(sql) + + # Insert test record using named parameter with scalar value + insert_sql = """ + INSERT INTO test_table (name) + VALUES (:name) + """ + driver.insert_update_delete(insert_sql, "test_name") + + # Select and verify + select_sql = "SELECT name FROM test_table WHERE name = :name" + results = driver.select(select_sql, "test_name") + assert len(results) == 1 + assert results[0]["name"] == "test_name" + driver.execute_script("DROP TABLE IF EXISTS test_table") + + +@xfail_if_driver_missing +@pytest.mark.xdist_group("sqlite") +def test_driver_named_params_with_tuple(adbc_session: AdbcConfig) -> None: + """Test that tuple parameters work with named parameters in SQL.""" + with adbc_session.provide_session() as driver: + sql = """ + CREATE TABLE test_table ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name VARCHAR(50), + age INTEGER + ); + """ + driver.execute_script(sql) + + # Insert test record using named parameters with tuple values + insert_sql = """ + INSERT INTO test_table (name, age) + VALUES (:name, :age) + """ + driver.insert_update_delete(insert_sql, ("test_name", 30)) + + # Select and verify + select_sql = "SELECT name, age FROM test_table WHERE name = :name AND age = :age" + results = driver.select(select_sql, ("test_name", 30)) + assert len(results) == 1 + assert results[0]["name"] == "test_name" + assert results[0]["age"] == 30 + driver.execute_script("DROP TABLE IF EXISTS test_table") diff --git a/tests/integration/test_adapters/test_asyncpg/__init__.py b/tests/integration/test_adapters/test_asyncpg/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/integration/test_adapters/test_asyncpg/test_connection.py b/tests/integration/test_adapters/test_asyncpg/test_connection.py new file mode 100644 index 000000000..31ba70171 --- /dev/null +++ b/tests/integration/test_adapters/test_asyncpg/test_connection.py @@ -0,0 +1,42 @@ +import pytest +from pytest_databases.docker.postgres import PostgresService + +from sqlspec.adapters.asyncpg import AsyncpgConfig, AsyncpgPoolConfig + + +@pytest.mark.xdist_group("postgres") +async def test_async_connection(postgres_service: PostgresService) -> None: + """Test asyncpg connection components.""" + # Test direct connection + async_config = AsyncpgConfig( + pool_config=AsyncpgPoolConfig( + dsn=f"postgres://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}", + ), + ) + + conn = await async_config.create_connection() + try: + assert conn is not None + # Test basic query + result = await conn.fetchval("SELECT 1") + assert result == 1 + finally: + await conn.close() + + # Test connection pool + pool_config = AsyncpgPoolConfig( + dsn=f"postgres://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}", + min_size=1, + max_size=5, + ) + another_config = AsyncpgConfig(pool_config=pool_config) + # Ensure the pool is created before use if not explicitly managed elsewhere + await another_config.create_pool() + try: + async with another_config.provide_connection() as conn: + assert conn is not None + # Test basic query + result = await conn.fetchval("SELECT 1") + assert result == 1 + finally: + await another_config.close_pool() diff --git a/tests/integration/test_adapters/test_asyncpg/test_driver.py b/tests/integration/test_adapters/test_asyncpg/test_driver.py new file mode 100644 index 000000000..4e62cedf8 --- /dev/null +++ b/tests/integration/test_adapters/test_asyncpg/test_driver.py @@ -0,0 +1,275 @@ +"""Test Asyncpg driver implementation.""" + +from __future__ import annotations + +from typing import Any, Literal + +import pytest +from pytest_databases.docker.postgres import PostgresService + +from sqlspec.adapters.asyncpg import AsyncpgConfig, AsyncpgPoolConfig + +ParamStyle = Literal["tuple_binds", "dict_binds"] + + +@pytest.fixture +def asyncpg_config(postgres_service: PostgresService) -> AsyncpgConfig: + """Create an Asyncpg configuration. + + Args: + postgres_service: PostgreSQL service fixture. + + Returns: + Configured Asyncpg session config. + """ + return AsyncpgConfig( + pool_config=AsyncpgPoolConfig( + dsn=f"postgres://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}", + min_size=1, # Add min_size to avoid pool deadlock issues in tests + max_size=5, + ) + ) + + +@pytest.mark.parametrize( + ("params", "style"), + [ + pytest.param(("test_name",), "tuple_binds", id="tuple_binds"), + pytest.param({"name": "test_name"}, "dict_binds", id="dict_binds"), + ], +) +@pytest.mark.xdist_group("postgres") +@pytest.mark.asyncio +async def test_async_insert_returning(asyncpg_config: AsyncpgConfig, params: Any, style: ParamStyle) -> None: + """Test async insert returning functionality with different parameter styles.""" + async with asyncpg_config.provide_session() as driver: + await driver.execute_script("DROP TABLE IF EXISTS test_table") # Ensure clean state + sql = """ + CREATE TABLE test_table ( + id SERIAL PRIMARY KEY, + name VARCHAR(50) + ); + """ + await driver.execute_script(sql) + + # Use appropriate SQL for each style (sqlspec driver handles conversion to $1, $2...) + if style == "tuple_binds": + sql = """ + INSERT INTO test_table (name) + VALUES (?) + RETURNING * + """ + else: # dict_binds + sql = """ + INSERT INTO test_table (name) + VALUES (:name) + RETURNING * + """ + + try: + result = await driver.insert_update_delete_returning(sql, params) + assert result is not None + assert result["name"] == "test_name" + assert result["id"] is not None + finally: + await driver.execute_script("DROP TABLE IF EXISTS test_table") + + +@pytest.mark.parametrize( + ("params", "style"), + [ + pytest.param(("test_name",), "tuple_binds", id="tuple_binds"), + pytest.param({"name": "test_name"}, "dict_binds", id="dict_binds"), + ], +) +@pytest.mark.xdist_group("postgres") +@pytest.mark.asyncio +async def test_async_select(asyncpg_config: AsyncpgConfig, params: Any, style: ParamStyle) -> None: + """Test async select functionality with different parameter styles.""" + async with asyncpg_config.provide_session() as driver: + await driver.execute_script("DROP TABLE IF EXISTS test_table") # Ensure clean state + # Create test table + sql = """ + CREATE TABLE test_table ( + id SERIAL PRIMARY KEY, + name VARCHAR(50) + ); + """ + await driver.execute_script(sql) + + # Insert test record + if style == "tuple_binds": + insert_sql = """ + INSERT INTO test_table (name) + VALUES (?) + """ + else: # dict_binds + insert_sql = """ + INSERT INTO test_table (name) + VALUES (:name) + """ + await driver.insert_update_delete(insert_sql, params) + + # Select and verify + if style == "tuple_binds": + select_sql = """ + SELECT name FROM test_table WHERE name = ? + """ + else: # dict_binds + select_sql = """ + SELECT name FROM test_table WHERE name = :name + """ + try: + results = await driver.select(select_sql, params) + assert len(results) == 1 + assert results[0]["name"] == "test_name" + finally: + await driver.execute_script("DROP TABLE IF EXISTS test_table") + + +@pytest.mark.parametrize( + ("params", "style"), + [ + pytest.param(("test_name",), "tuple_binds", id="tuple_binds"), + pytest.param({"name": "test_name"}, "dict_binds", id="dict_binds"), + ], +) +@pytest.mark.xdist_group("postgres") +@pytest.mark.asyncio +async def test_async_select_value(asyncpg_config: AsyncpgConfig, params: Any, style: ParamStyle) -> None: + """Test async select_value functionality with different parameter styles.""" + async with asyncpg_config.provide_session() as driver: + await driver.execute_script("DROP TABLE IF EXISTS test_table") # Ensure clean state + # Create test table + sql = """ + CREATE TABLE test_table ( + id SERIAL PRIMARY KEY, + name VARCHAR(50) + ); + """ + await driver.execute_script(sql) + + # Insert test record + if style == "tuple_binds": + insert_sql = """ + INSERT INTO test_table (name) + VALUES (?) + """ + else: # dict_binds + insert_sql = """ + INSERT INTO test_table (name) + VALUES (:name) + """ + await driver.insert_update_delete(insert_sql, params) + + # Get literal string to test with select_value + # Use a literal query to test select_value + select_sql = "SELECT 'test_name' AS test_name" + + try: + # Don't pass parameters with a literal query that has no placeholders + value = await driver.select_value(select_sql) + assert value == "test_name" + finally: + await driver.execute_script("DROP TABLE IF EXISTS test_table") + + +@pytest.mark.xdist_group("postgres") +@pytest.mark.asyncio +async def test_insert(asyncpg_config: AsyncpgConfig) -> None: + """Test inserting data.""" + async with asyncpg_config.provide_session() as driver: + await driver.execute_script("DROP TABLE IF EXISTS test_table") # Ensure clean state + sql = """ + CREATE TABLE test_table ( + id SERIAL PRIMARY KEY, + name VARCHAR(50) + ) + """ + await driver.execute_script(sql) + + insert_sql = "INSERT INTO test_table (name) VALUES (?)" + try: + row_count = await driver.insert_update_delete(insert_sql, ("test",)) + assert row_count == 1 + + # Verify insertion + select_sql = "SELECT COUNT(*) FROM test_table WHERE name = ?" + count = await driver.select_value(select_sql, ("test",)) + assert count == 1 + finally: + await driver.execute_script("DROP TABLE IF EXISTS test_table") + + +@pytest.mark.xdist_group("postgres") +@pytest.mark.asyncio +async def test_select(asyncpg_config: AsyncpgConfig) -> None: + """Test selecting data.""" + async with asyncpg_config.provide_session() as driver: + await driver.execute_script("DROP TABLE IF EXISTS test_table") # Ensure clean state + # Create and populate test table + sql = """ + CREATE TABLE test_table ( + id SERIAL PRIMARY KEY, + name VARCHAR(50) + ) + """ + await driver.execute_script(sql) + + insert_sql = "INSERT INTO test_table (name) VALUES (?)" + await driver.insert_update_delete(insert_sql, ("test",)) + + # Select and verify + select_sql = "SELECT name FROM test_table WHERE id = ?" + try: + results = await driver.select(select_sql, (1,)) + assert len(results) == 1 + assert results[0]["name"] == "test" + finally: + await driver.execute_script("DROP TABLE IF EXISTS test_table") + + +# Asyncpg uses positional ($n) parameters internally. +# The sqlspec driver converts '?' (tuple) and ':name' (dict) styles. +# We test these two styles as they are what the user interacts with via sqlspec. +@pytest.mark.parametrize( + "param_style", + [ + "tuple_binds", # Corresponds to '?' in SQL passed to sqlspec + "dict_binds", # Corresponds to ':name' in SQL passed to sqlspec + ], +) +@pytest.mark.xdist_group("postgres") +@pytest.mark.asyncio +async def test_param_styles(asyncpg_config: AsyncpgConfig, param_style: str) -> None: + """Test different parameter styles expected by sqlspec.""" + async with asyncpg_config.provide_session() as driver: + await driver.execute_script("DROP TABLE IF EXISTS test_table") # Ensure clean state + # Create test table + sql = """ + CREATE TABLE test_table ( + id SERIAL PRIMARY KEY, + name VARCHAR(50) + ) + """ + await driver.execute_script(sql) + + # Insert test record based on param style + if param_style == "tuple_binds": + insert_sql = "INSERT INTO test_table (name) VALUES (?)" + params: Any = ("test",) + else: # dict_binds + insert_sql = "INSERT INTO test_table (name) VALUES (:name)" + params = {"name": "test"} + + try: + row_count = await driver.insert_update_delete(insert_sql, params) + assert row_count == 1 + + # Select and verify + select_sql = "SELECT name FROM test_table WHERE id = ?" + results = await driver.select(select_sql, (1,)) + assert len(results) == 1 + assert results[0]["name"] == "test" + finally: + await driver.execute_script("DROP TABLE IF EXISTS test_table")