11"""ADBC driver implementation for Arrow Database Connectivity.
22
3- Provides ADBC driver integration with multi-dialect database connections,
4- Arrow-native data handling with type coercion, parameter style conversion
5- for different database backends, and transaction management.
3+ Provides database connectivity through ADBC with support for multiple
4+ database dialects, parameter style conversion, and transaction management.
65"""
76
87import contextlib
98import datetime
109import decimal
1110from typing import TYPE_CHECKING , Any , Optional , cast
1211
12+ from adbc_driver_manager .dbapi import DatabaseError , IntegrityError , OperationalError , ProgrammingError
1313from sqlglot import exp
1414
1515from sqlspec .core .cache import get_cache_config
5353}
5454
5555
56- def _adbc_ast_transformer (expression : Any , parameters : Any ) -> tuple [ Any , Any ] :
57- """AST transformer for NULL parameter handling .
56+ def _count_placeholders (expression : Any ) -> int :
57+ """Count the number of unique parameter placeholders in a SQLGlot expression .
5858
59- For PostgreSQL, replaces NULL parameter placeholders with NULL literals
60- in the AST to prevent Arrow from inferring 'na' types which cause binding errors.
59+ For PostgreSQL ($1, $2) style: counts highest numbered parameter (e.g., $1, $1, $2 = 2)
60+ For QMARK (?) style: counts total occurrences (each ? is a separate parameter)
61+ For named (:name) style: counts unique parameter names
6162
6263 Args:
6364 expression: SQLGlot AST expression
64- parameters: Parameter values that may contain None
6565
6666 Returns:
67- Tuple of (modified_expression, cleaned_parameters)
67+ Number of unique parameter placeholders expected
6868 """
69- if not parameters :
70- return expression , parameters
69+ numeric_params = set () # For $1, $2 style
70+ qmark_count = 0 # For ? style
71+ named_params = set () # For :name style
72+
73+ def count_node (node : Any ) -> Any :
74+ nonlocal qmark_count
75+ if isinstance (node , exp .Parameter ):
76+ # PostgreSQL style: $1, $2, etc.
77+ param_str = str (node )
78+ if param_str .startswith ("$" ) and param_str [1 :].isdigit ():
79+ numeric_params .add (int (param_str [1 :]))
80+ elif ":" in param_str :
81+ # Named parameter: :name
82+ named_params .add (param_str )
83+ else :
84+ # Other parameter formats
85+ named_params .add (param_str )
86+ elif isinstance (node , exp .Placeholder ):
87+ # QMARK style: ?
88+ qmark_count += 1
89+ return node
90+
91+ expression .transform (count_node )
92+
93+ # Return the appropriate count based on parameter style detected
94+ if numeric_params :
95+ # PostgreSQL style: return highest numbered parameter
96+ return max (numeric_params )
97+ if named_params :
98+ # Named parameters: return count of unique names
99+ return len (named_params )
100+ # QMARK style: return total count
101+ return qmark_count
102+
103+
104+ def _is_execute_many_parameters (parameters : Any ) -> bool :
105+ """Check if parameters are in execute_many format (list/tuple of lists/tuples)."""
106+ return isinstance (parameters , (list , tuple )) and len (parameters ) > 0 and isinstance (parameters [0 ], (list , tuple ))
107+
71108
109+ def _validate_parameter_counts (expression : Any , parameters : Any , dialect : str ) -> None :
110+ """Validate parameter count against placeholder count in SQL."""
111+ placeholder_count = _count_placeholders (expression )
112+ is_execute_many = _is_execute_many_parameters (parameters )
113+
114+ if is_execute_many :
115+ # For execute_many, validate each inner parameter set
116+ for i , param_set in enumerate (parameters ):
117+ param_count = len (param_set ) if isinstance (param_set , (list , tuple )) else 0
118+ if param_count != placeholder_count :
119+ msg = f"Parameter count mismatch in set { i } : { param_count } parameters provided but { placeholder_count } placeholders in SQL (dialect: { dialect } )"
120+ raise SQLSpecError (msg )
121+ else :
122+ # For single execution, validate the parameter set directly
123+ param_count = (
124+ len (parameters )
125+ if isinstance (parameters , (list , tuple ))
126+ else len (parameters )
127+ if isinstance (parameters , dict )
128+ else 0
129+ )
130+
131+ if param_count != placeholder_count :
132+ msg = f"Parameter count mismatch: { param_count } parameters provided but { placeholder_count } placeholders in SQL (dialect: { dialect } )"
133+ raise SQLSpecError (msg )
134+
135+
136+ def _find_null_positions (parameters : Any ) -> set [int ]:
137+ """Find positions of None values in parameters for single execution."""
72138 null_positions = set ()
73139 if isinstance (parameters , (list , tuple )):
74140 for i , param in enumerate (parameters ):
@@ -83,7 +149,37 @@ def _adbc_ast_transformer(expression: Any, parameters: Any) -> tuple[Any, Any]:
83149 null_positions .add (param_num - 1 )
84150 except ValueError :
85151 pass
152+ return null_positions
153+
154+
155+ def _adbc_ast_transformer (expression : Any , parameters : Any , dialect : str = "postgres" ) -> tuple [Any , Any ]:
156+ """Transform AST to handle NULL parameters.
86157
158+ Replaces NULL parameter placeholders with NULL literals in the AST
159+ to prevent Arrow from inferring 'na' types which cause binding errors.
160+ Validates parameter count before transformation.
161+
162+ Args:
163+ expression: SQLGlot AST expression parsed with proper dialect
164+ parameters: Parameter values that may contain None
165+ dialect: SQLGlot dialect used for parsing (default: "postgres")
166+
167+ Returns:
168+ Tuple of (modified_expression, cleaned_parameters)
169+ """
170+ if not parameters :
171+ return expression , parameters
172+
173+ # Validate parameter count before transformation
174+ _validate_parameter_counts (expression , parameters , dialect )
175+
176+ # For execute_many operations, skip AST transformation as different parameter
177+ # sets may have None values in different positions, making transformation complex
178+ if _is_execute_many_parameters (parameters ):
179+ return expression , parameters
180+
181+ # Find positions of None values for single execution
182+ null_positions = _find_null_positions (parameters )
87183 if not null_positions :
88184 return expression , parameters
89185
@@ -183,14 +279,28 @@ def get_adbc_statement_config(detected_dialect: str) -> StatementConfig:
183279
184280
185281def _convert_array_for_postgres_adbc (value : Any ) -> Any :
186- """Convert array values for PostgreSQL compatibility."""
282+ """Convert array values for PostgreSQL compatibility.
283+
284+ Args:
285+ value: Value to convert
286+
287+ Returns:
288+ Converted value (tuples become lists)
289+ """
187290 if isinstance (value , tuple ):
188291 return list (value )
189292 return value
190293
191294
192295def get_type_coercion_map (dialect : str ) -> "dict[type, Any]" :
193- """Get type coercion map for Arrow type handling."""
296+ """Get type coercion map for Arrow type handling.
297+
298+ Args:
299+ dialect: Database dialect name
300+
301+ Returns:
302+ Mapping of Python types to conversion functions
303+ """
194304 type_map = {
195305 datetime .datetime : lambda x : x ,
196306 datetime .date : lambda x : x ,
@@ -245,8 +355,6 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
245355 return
246356
247357 try :
248- from adbc_driver_manager .dbapi import DatabaseError , IntegrityError , OperationalError , ProgrammingError
249-
250358 if issubclass (exc_type , IntegrityError ):
251359 e = exc_val
252360 msg = f"Integrity constraint violation: { e } "
@@ -282,9 +390,8 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
282390class AdbcDriver (SyncDriverAdapterBase ):
283391 """ADBC driver for Arrow Database Connectivity.
284392
285- Provides database connectivity through ADBC with multi-database dialect
286- support, Arrow-native data handling with type coercion, parameter style
287- conversion for different backends, and transaction management.
393+ Provides database connectivity through ADBC with support for multiple
394+ database dialects, parameter style conversion, and transaction management.
288395 """
289396
290397 __slots__ = ("_detected_dialect" , "dialect" )
@@ -309,15 +416,26 @@ def __init__(
309416
310417 @staticmethod
311418 def _ensure_pyarrow_installed () -> None :
312- """Ensure PyArrow is installed."""
419+ """Ensure PyArrow is installed.
420+
421+ Raises:
422+ MissingDependencyError: If PyArrow is not installed
423+ """
313424 from sqlspec .typing import PYARROW_INSTALLED
314425
315426 if not PYARROW_INSTALLED :
316427 raise MissingDependencyError (package = "pyarrow" , install_package = "arrow" )
317428
318429 @staticmethod
319430 def _get_dialect (connection : "AdbcConnection" ) -> str :
320- """Detect database dialect from connection information."""
431+ """Detect database dialect from connection information.
432+
433+ Args:
434+ connection: ADBC connection
435+
436+ Returns:
437+ Detected dialect name (defaults to 'postgres')
438+ """
321439 try :
322440 driver_info = connection .adbc_get_info ()
323441 vendor_name = driver_info .get ("vendor_name" , "" ).lower ()
@@ -334,31 +452,53 @@ def _get_dialect(connection: "AdbcConnection") -> str:
334452 return "postgres"
335453
336454 def _handle_postgres_rollback (self , cursor : "Cursor" ) -> None :
337- """Execute rollback for PostgreSQL after transaction failure."""
455+ """Execute rollback for PostgreSQL after transaction failure.
456+
457+ Args:
458+ cursor: Database cursor
459+ """
338460 if self .dialect == "postgres" :
339461 with contextlib .suppress (Exception ):
340462 cursor .execute ("ROLLBACK" )
341463 logger .debug ("PostgreSQL rollback executed after transaction failure" )
342464
343465 def _handle_postgres_empty_parameters (self , parameters : Any ) -> Any :
344- """Process empty parameters for PostgreSQL compatibility."""
466+ """Process empty parameters for PostgreSQL compatibility.
467+
468+ Args:
469+ parameters: Parameter values
470+
471+ Returns:
472+ Processed parameters
473+ """
345474 if self .dialect == "postgres" and isinstance (parameters , dict ) and not parameters :
346475 return None
347476 return parameters
348477
349478 def with_cursor (self , connection : "AdbcConnection" ) -> "AdbcCursor" :
350- """Create context manager for cursor."""
479+ """Create context manager for cursor.
480+
481+ Args:
482+ connection: Database connection
483+
484+ Returns:
485+ Cursor context manager
486+ """
351487 return AdbcCursor (connection )
352488
353489 def handle_database_exceptions (self ) -> "AbstractContextManager[None]" :
354- """Handle database-specific exceptions and wrap them appropriately."""
490+ """Handle database-specific exceptions and wrap them appropriately.
491+
492+ Returns:
493+ Exception handler context manager
494+ """
355495 return AdbcExceptionHandler ()
356496
357497 def _try_special_handling (self , cursor : "Cursor" , statement : SQL ) -> "Optional[SQLResult]" :
358498 """Handle special operations.
359499
360500 Args:
361- cursor: Cursor object
501+ cursor: Database cursor
362502 statement: SQL statement to analyze
363503
364504 Returns:
@@ -368,7 +508,15 @@ def _try_special_handling(self, cursor: "Cursor", statement: SQL) -> "Optional[S
368508 return None
369509
370510 def _execute_many (self , cursor : "Cursor" , statement : SQL ) -> "ExecutionResult" :
371- """Execute SQL with multiple parameter sets."""
511+ """Execute SQL with multiple parameter sets.
512+
513+ Args:
514+ cursor: Database cursor
515+ statement: SQL statement to execute
516+
517+ Returns:
518+ Execution result with row counts
519+ """
372520 sql , prepared_parameters = self ._get_compiled_sql (statement , self .statement_config )
373521
374522 try :
@@ -398,7 +546,15 @@ def _execute_many(self, cursor: "Cursor", statement: SQL) -> "ExecutionResult":
398546 return self .create_execution_result (cursor , rowcount_override = row_count , is_many_result = True )
399547
400548 def _execute_statement (self , cursor : "Cursor" , statement : SQL ) -> "ExecutionResult" :
401- """Execute single SQL statement."""
549+ """Execute single SQL statement.
550+
551+ Args:
552+ cursor: Database cursor
553+ statement: SQL statement to execute
554+
555+ Returns:
556+ Execution result with data or row count
557+ """
402558 sql , prepared_parameters = self ._get_compiled_sql (statement , self .statement_config )
403559
404560 try :
@@ -430,7 +586,15 @@ def _execute_statement(self, cursor: "Cursor", statement: SQL) -> "ExecutionResu
430586 return self .create_execution_result (cursor , rowcount_override = row_count )
431587
432588 def _execute_script (self , cursor : "Cursor" , statement : "SQL" ) -> "ExecutionResult" :
433- """Execute SQL script."""
589+ """Execute SQL script containing multiple statements.
590+
591+ Args:
592+ cursor: Database cursor
593+ statement: SQL script to execute
594+
595+ Returns:
596+ Execution result with statement counts
597+ """
434598 if statement .is_script :
435599 sql = statement ._raw_sql
436600 prepared_parameters : list [Any ] = []
0 commit comments