Skip to content

Commit 5a1e46e

Browse files
authored
fix: null parameter handling (#64)
Enhance handling of null parameters in SQL expressions to prevent binding errors and ensure parameter count validation.
1 parent 10157c9 commit 5a1e46e

File tree

170 files changed

+5814
-3440
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

170 files changed

+5814
-3440
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ repos:
1717
- id: mixed-line-ending
1818
- id: trailing-whitespace
1919
- repo: https://github.com/charliermarsh/ruff-pre-commit
20-
rev: "v0.12.9"
20+
rev: "v0.12.10"
2121
hooks:
2222
- id: ruff
2323
args: ["--fix"]

sqlspec/adapters/adbc/driver.py

Lines changed: 192 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
"""ADBC driver implementation for Arrow Database Connectivity.
22
3-
Provides ADBC driver integration with multi-dialect database connections,
4-
Arrow-native data handling with type coercion, parameter style conversion
5-
for different database backends, and transaction management.
3+
Provides database connectivity through ADBC with support for multiple
4+
database dialects, parameter style conversion, and transaction management.
65
"""
76

87
import contextlib
98
import datetime
109
import decimal
1110
from typing import TYPE_CHECKING, Any, Optional, cast
1211

12+
from adbc_driver_manager.dbapi import DatabaseError, IntegrityError, OperationalError, ProgrammingError
1313
from sqlglot import exp
1414

1515
from sqlspec.core.cache import get_cache_config
@@ -53,22 +53,88 @@
5353
}
5454

5555

56-
def _adbc_ast_transformer(expression: Any, parameters: Any) -> tuple[Any, Any]:
57-
"""AST transformer for NULL parameter handling.
56+
def _count_placeholders(expression: Any) -> int:
57+
"""Count the number of unique parameter placeholders in a SQLGlot expression.
5858
59-
For PostgreSQL, replaces NULL parameter placeholders with NULL literals
60-
in the AST to prevent Arrow from inferring 'na' types which cause binding errors.
59+
For PostgreSQL ($1, $2) style: counts highest numbered parameter (e.g., $1, $1, $2 = 2)
60+
For QMARK (?) style: counts total occurrences (each ? is a separate parameter)
61+
For named (:name) style: counts unique parameter names
6162
6263
Args:
6364
expression: SQLGlot AST expression
64-
parameters: Parameter values that may contain None
6565
6666
Returns:
67-
Tuple of (modified_expression, cleaned_parameters)
67+
Number of unique parameter placeholders expected
6868
"""
69-
if not parameters:
70-
return expression, parameters
69+
numeric_params = set() # For $1, $2 style
70+
qmark_count = 0 # For ? style
71+
named_params = set() # For :name style
72+
73+
def count_node(node: Any) -> Any:
74+
nonlocal qmark_count
75+
if isinstance(node, exp.Parameter):
76+
# PostgreSQL style: $1, $2, etc.
77+
param_str = str(node)
78+
if param_str.startswith("$") and param_str[1:].isdigit():
79+
numeric_params.add(int(param_str[1:]))
80+
elif ":" in param_str:
81+
# Named parameter: :name
82+
named_params.add(param_str)
83+
else:
84+
# Other parameter formats
85+
named_params.add(param_str)
86+
elif isinstance(node, exp.Placeholder):
87+
# QMARK style: ?
88+
qmark_count += 1
89+
return node
90+
91+
expression.transform(count_node)
92+
93+
# Return the appropriate count based on parameter style detected
94+
if numeric_params:
95+
# PostgreSQL style: return highest numbered parameter
96+
return max(numeric_params)
97+
if named_params:
98+
# Named parameters: return count of unique names
99+
return len(named_params)
100+
# QMARK style: return total count
101+
return qmark_count
102+
103+
104+
def _is_execute_many_parameters(parameters: Any) -> bool:
105+
"""Check if parameters are in execute_many format (list/tuple of lists/tuples)."""
106+
return isinstance(parameters, (list, tuple)) and len(parameters) > 0 and isinstance(parameters[0], (list, tuple))
107+
71108

109+
def _validate_parameter_counts(expression: Any, parameters: Any, dialect: str) -> None:
110+
"""Validate parameter count against placeholder count in SQL."""
111+
placeholder_count = _count_placeholders(expression)
112+
is_execute_many = _is_execute_many_parameters(parameters)
113+
114+
if is_execute_many:
115+
# For execute_many, validate each inner parameter set
116+
for i, param_set in enumerate(parameters):
117+
param_count = len(param_set) if isinstance(param_set, (list, tuple)) else 0
118+
if param_count != placeholder_count:
119+
msg = f"Parameter count mismatch in set {i}: {param_count} parameters provided but {placeholder_count} placeholders in SQL (dialect: {dialect})"
120+
raise SQLSpecError(msg)
121+
else:
122+
# For single execution, validate the parameter set directly
123+
param_count = (
124+
len(parameters)
125+
if isinstance(parameters, (list, tuple))
126+
else len(parameters)
127+
if isinstance(parameters, dict)
128+
else 0
129+
)
130+
131+
if param_count != placeholder_count:
132+
msg = f"Parameter count mismatch: {param_count} parameters provided but {placeholder_count} placeholders in SQL (dialect: {dialect})"
133+
raise SQLSpecError(msg)
134+
135+
136+
def _find_null_positions(parameters: Any) -> set[int]:
137+
"""Find positions of None values in parameters for single execution."""
72138
null_positions = set()
73139
if isinstance(parameters, (list, tuple)):
74140
for i, param in enumerate(parameters):
@@ -83,7 +149,37 @@ def _adbc_ast_transformer(expression: Any, parameters: Any) -> tuple[Any, Any]:
83149
null_positions.add(param_num - 1)
84150
except ValueError:
85151
pass
152+
return null_positions
153+
154+
155+
def _adbc_ast_transformer(expression: Any, parameters: Any, dialect: str = "postgres") -> tuple[Any, Any]:
156+
"""Transform AST to handle NULL parameters.
86157
158+
Replaces NULL parameter placeholders with NULL literals in the AST
159+
to prevent Arrow from inferring 'na' types which cause binding errors.
160+
Validates parameter count before transformation.
161+
162+
Args:
163+
expression: SQLGlot AST expression parsed with proper dialect
164+
parameters: Parameter values that may contain None
165+
dialect: SQLGlot dialect used for parsing (default: "postgres")
166+
167+
Returns:
168+
Tuple of (modified_expression, cleaned_parameters)
169+
"""
170+
if not parameters:
171+
return expression, parameters
172+
173+
# Validate parameter count before transformation
174+
_validate_parameter_counts(expression, parameters, dialect)
175+
176+
# For execute_many operations, skip AST transformation as different parameter
177+
# sets may have None values in different positions, making transformation complex
178+
if _is_execute_many_parameters(parameters):
179+
return expression, parameters
180+
181+
# Find positions of None values for single execution
182+
null_positions = _find_null_positions(parameters)
87183
if not null_positions:
88184
return expression, parameters
89185

@@ -183,14 +279,28 @@ def get_adbc_statement_config(detected_dialect: str) -> StatementConfig:
183279

184280

185281
def _convert_array_for_postgres_adbc(value: Any) -> Any:
186-
"""Convert array values for PostgreSQL compatibility."""
282+
"""Convert array values for PostgreSQL compatibility.
283+
284+
Args:
285+
value: Value to convert
286+
287+
Returns:
288+
Converted value (tuples become lists)
289+
"""
187290
if isinstance(value, tuple):
188291
return list(value)
189292
return value
190293

191294

192295
def get_type_coercion_map(dialect: str) -> "dict[type, Any]":
193-
"""Get type coercion map for Arrow type handling."""
296+
"""Get type coercion map for Arrow type handling.
297+
298+
Args:
299+
dialect: Database dialect name
300+
301+
Returns:
302+
Mapping of Python types to conversion functions
303+
"""
194304
type_map = {
195305
datetime.datetime: lambda x: x,
196306
datetime.date: lambda x: x,
@@ -245,8 +355,6 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
245355
return
246356

247357
try:
248-
from adbc_driver_manager.dbapi import DatabaseError, IntegrityError, OperationalError, ProgrammingError
249-
250358
if issubclass(exc_type, IntegrityError):
251359
e = exc_val
252360
msg = f"Integrity constraint violation: {e}"
@@ -282,9 +390,8 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
282390
class AdbcDriver(SyncDriverAdapterBase):
283391
"""ADBC driver for Arrow Database Connectivity.
284392
285-
Provides database connectivity through ADBC with multi-database dialect
286-
support, Arrow-native data handling with type coercion, parameter style
287-
conversion for different backends, and transaction management.
393+
Provides database connectivity through ADBC with support for multiple
394+
database dialects, parameter style conversion, and transaction management.
288395
"""
289396

290397
__slots__ = ("_detected_dialect", "dialect")
@@ -309,15 +416,26 @@ def __init__(
309416

310417
@staticmethod
311418
def _ensure_pyarrow_installed() -> None:
312-
"""Ensure PyArrow is installed."""
419+
"""Ensure PyArrow is installed.
420+
421+
Raises:
422+
MissingDependencyError: If PyArrow is not installed
423+
"""
313424
from sqlspec.typing import PYARROW_INSTALLED
314425

315426
if not PYARROW_INSTALLED:
316427
raise MissingDependencyError(package="pyarrow", install_package="arrow")
317428

318429
@staticmethod
319430
def _get_dialect(connection: "AdbcConnection") -> str:
320-
"""Detect database dialect from connection information."""
431+
"""Detect database dialect from connection information.
432+
433+
Args:
434+
connection: ADBC connection
435+
436+
Returns:
437+
Detected dialect name (defaults to 'postgres')
438+
"""
321439
try:
322440
driver_info = connection.adbc_get_info()
323441
vendor_name = driver_info.get("vendor_name", "").lower()
@@ -334,31 +452,53 @@ def _get_dialect(connection: "AdbcConnection") -> str:
334452
return "postgres"
335453

336454
def _handle_postgres_rollback(self, cursor: "Cursor") -> None:
337-
"""Execute rollback for PostgreSQL after transaction failure."""
455+
"""Execute rollback for PostgreSQL after transaction failure.
456+
457+
Args:
458+
cursor: Database cursor
459+
"""
338460
if self.dialect == "postgres":
339461
with contextlib.suppress(Exception):
340462
cursor.execute("ROLLBACK")
341463
logger.debug("PostgreSQL rollback executed after transaction failure")
342464

343465
def _handle_postgres_empty_parameters(self, parameters: Any) -> Any:
344-
"""Process empty parameters for PostgreSQL compatibility."""
466+
"""Process empty parameters for PostgreSQL compatibility.
467+
468+
Args:
469+
parameters: Parameter values
470+
471+
Returns:
472+
Processed parameters
473+
"""
345474
if self.dialect == "postgres" and isinstance(parameters, dict) and not parameters:
346475
return None
347476
return parameters
348477

349478
def with_cursor(self, connection: "AdbcConnection") -> "AdbcCursor":
350-
"""Create context manager for cursor."""
479+
"""Create context manager for cursor.
480+
481+
Args:
482+
connection: Database connection
483+
484+
Returns:
485+
Cursor context manager
486+
"""
351487
return AdbcCursor(connection)
352488

353489
def handle_database_exceptions(self) -> "AbstractContextManager[None]":
354-
"""Handle database-specific exceptions and wrap them appropriately."""
490+
"""Handle database-specific exceptions and wrap them appropriately.
491+
492+
Returns:
493+
Exception handler context manager
494+
"""
355495
return AdbcExceptionHandler()
356496

357497
def _try_special_handling(self, cursor: "Cursor", statement: SQL) -> "Optional[SQLResult]":
358498
"""Handle special operations.
359499
360500
Args:
361-
cursor: Cursor object
501+
cursor: Database cursor
362502
statement: SQL statement to analyze
363503
364504
Returns:
@@ -368,7 +508,15 @@ def _try_special_handling(self, cursor: "Cursor", statement: SQL) -> "Optional[S
368508
return None
369509

370510
def _execute_many(self, cursor: "Cursor", statement: SQL) -> "ExecutionResult":
371-
"""Execute SQL with multiple parameter sets."""
511+
"""Execute SQL with multiple parameter sets.
512+
513+
Args:
514+
cursor: Database cursor
515+
statement: SQL statement to execute
516+
517+
Returns:
518+
Execution result with row counts
519+
"""
372520
sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
373521

374522
try:
@@ -398,7 +546,15 @@ def _execute_many(self, cursor: "Cursor", statement: SQL) -> "ExecutionResult":
398546
return self.create_execution_result(cursor, rowcount_override=row_count, is_many_result=True)
399547

400548
def _execute_statement(self, cursor: "Cursor", statement: SQL) -> "ExecutionResult":
401-
"""Execute single SQL statement."""
549+
"""Execute single SQL statement.
550+
551+
Args:
552+
cursor: Database cursor
553+
statement: SQL statement to execute
554+
555+
Returns:
556+
Execution result with data or row count
557+
"""
402558
sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
403559

404560
try:
@@ -430,7 +586,15 @@ def _execute_statement(self, cursor: "Cursor", statement: SQL) -> "ExecutionResu
430586
return self.create_execution_result(cursor, rowcount_override=row_count)
431587

432588
def _execute_script(self, cursor: "Cursor", statement: "SQL") -> "ExecutionResult":
433-
"""Execute SQL script."""
589+
"""Execute SQL script containing multiple statements.
590+
591+
Args:
592+
cursor: Database cursor
593+
statement: SQL script to execute
594+
595+
Returns:
596+
Execution result with statement counts
597+
"""
434598
if statement.is_script:
435599
sql = statement._raw_sql
436600
prepared_parameters: list[Any] = []

0 commit comments

Comments
 (0)