|
1 | 1 | # pyright: reportCallIssue=false, reportAttributeAccessIssue=false, reportArgumentType=false |
| 2 | +import logging |
2 | 3 | from typing import TYPE_CHECKING, Any, Optional |
3 | 4 |
|
4 | | -if TYPE_CHECKING: |
5 | | - from sqlspec.driver._common import ExecutionResult |
6 | | - from sqlspec.statement.result import SQLResult |
7 | | - from sqlspec.statement.sql import SQL |
8 | | - |
9 | 5 | from sqlspec.adapters.psycopg._types import PsycopgAsyncConnection, PsycopgSyncConnection |
10 | 6 | from sqlspec.adapters.psycopg.mixins import PsycopgCopyMixin |
11 | 7 | from sqlspec.adapters.psycopg.pipeline_steps import postgres_copy_pipeline_step |
|
14 | 10 | from sqlspec.parameters.config import ParameterStyleConfig |
15 | 11 | from sqlspec.statement.sql import SQL, StatementConfig |
16 | 12 |
|
| 13 | +if TYPE_CHECKING: |
| 14 | + from sqlspec.driver._common import ExecutionResult |
| 15 | + from sqlspec.statement.result import SQLResult |
| 16 | + from sqlspec.statement.sql import SQL |
| 17 | + |
| 18 | +logger = logging.getLogger(__name__) |
| 19 | + |
| 20 | + |
17 | 21 | psycopg_statement_config = StatementConfig( |
18 | 22 | dialect="postgres", |
19 | 23 | parameter_config=ParameterStyleConfig( |
|
22 | 26 | ParameterStyle.POSITIONAL_PYFORMAT, |
23 | 27 | ParameterStyle.NAMED_PYFORMAT, |
24 | 28 | ParameterStyle.NUMERIC, |
| 29 | + ParameterStyle.QMARK, # Add support for ? placeholders |
25 | 30 | }, |
26 | 31 | execution_parameter_style=ParameterStyle.POSITIONAL_PYFORMAT, # Convert all to positional for reliability |
27 | 32 | type_coercion_map={}, |
@@ -102,24 +107,64 @@ def _try_special_handling(self, cursor: Any, statement: "SQL") -> "Optional[SQLR |
102 | 107 | None if standard execution should proceed |
103 | 108 | """ |
104 | 109 | # Check if this is a COPY statement marked by the pipeline |
105 | | - if statement._processing_context and statement._processing_context.metadata.get("postgres_copy_operation"): |
106 | | - result = self._handle_copy_operation_from_pipeline(cursor, statement) |
107 | 110 |
|
108 | | - # Create ExecutionResult and build SQLResult directly |
109 | | - execution_result = self.create_execution_result(result) |
110 | | - return self.build_statement_result(statement, execution_result) |
| 111 | + if statement._processing_context and statement._processing_context.metadata.get("postgres_copy_operation"): |
| 112 | + try: |
| 113 | + result = self._handle_copy_operation_from_pipeline(cursor, statement) |
| 114 | + |
| 115 | + # Create ExecutionResult and build SQLResult directly - capture rowcount from cursor |
| 116 | + row_count = result.rowcount if hasattr(result, "rowcount") and result.rowcount is not None else -1 |
| 117 | + execution_result = self.create_execution_result(result, rowcount_override=row_count) |
| 118 | + return self.build_statement_result(statement, execution_result) |
| 119 | + except Exception: |
| 120 | + # Log the error but don't fail silently |
| 121 | + logger.exception("COPY operation failed in special handling") |
| 122 | + raise |
111 | 123 |
|
112 | 124 | return None |
113 | 125 |
|
114 | 126 | def _execute_copy_with_data(self, cursor: Any, sql_text: str, data_str: str) -> Any: |
115 | 127 | """Execute COPY operation with data using Psycopg sync context manager.""" |
116 | 128 | with cursor.copy(sql_text) as copy: |
117 | 129 | copy.write(data_str) |
| 130 | + # Return cursor after copy operation completes - rowcount should be available |
| 131 | + return cursor |
118 | 132 |
|
119 | 133 | def _execute_copy_without_data(self, cursor: Any, sql_text: str) -> Any: |
120 | 134 | """Execute COPY operation without data using Psycopg sync context manager.""" |
121 | 135 | with cursor.copy(sql_text): |
122 | 136 | pass # Just execute the COPY command |
| 137 | + return cursor |
| 138 | + |
| 139 | + def _handle_copy_operation_from_pipeline(self, cursor: Any, statement: "SQL") -> Any: |
| 140 | + """Sync version of COPY handling using the pipeline metadata.""" |
| 141 | + # Get the original SQL from pipeline metadata |
| 142 | + metadata = statement._processing_context.metadata if statement._processing_context else {} |
| 143 | + sql_text = metadata.get("postgres_copy_original_sql") |
| 144 | + if not sql_text: |
| 145 | + # Fallback to expression |
| 146 | + sql_text = str(statement.expression) |
| 147 | + |
| 148 | + # Get the raw COPY data from pipeline metadata |
| 149 | + copy_data = metadata.get("postgres_copy_data") |
| 150 | + |
| 151 | + if copy_data: |
| 152 | + # Handle different parameter formats (positional or keyword) |
| 153 | + if isinstance(copy_data, dict): |
| 154 | + # For named parameters, assume single data value or concatenate all values |
| 155 | + if len(copy_data) == 1: |
| 156 | + data_str = str(next(iter(copy_data.values()))) |
| 157 | + else: |
| 158 | + data_str = "\n".join(str(value) for value in copy_data.values()) |
| 159 | + elif isinstance(copy_data, (list, tuple)): |
| 160 | + # For positional parameters, if single item, use as is, otherwise join |
| 161 | + data_str = str(copy_data[0]) if len(copy_data) == 1 else "\n".join(str(value) for value in copy_data) |
| 162 | + else: |
| 163 | + data_str = str(copy_data) |
| 164 | + |
| 165 | + return self._execute_copy_with_data(cursor, sql_text, data_str) |
| 166 | + # COPY without data (e.g., COPY TO STDOUT) |
| 167 | + return self._execute_copy_without_data(cursor, sql_text) |
123 | 168 |
|
124 | 169 | def _execute_script( |
125 | 170 | self, cursor: Any, sql: str, prepared_params: Any, statement_config: "StatementConfig", statement: "SQL" |
@@ -215,24 +260,34 @@ async def _try_special_handling(self, cursor: Any, statement: "SQL") -> "Optiona |
215 | 260 | None if standard execution should proceed |
216 | 261 | """ |
217 | 262 | # Check if this is a COPY statement marked by the pipeline |
218 | | - if statement._processing_context and statement._processing_context.metadata.get("postgres_copy_operation"): |
219 | | - result = await self._handle_copy_operation_from_pipeline(cursor, statement) |
220 | 263 |
|
221 | | - # Create ExecutionResult and build SQLResult directly |
222 | | - execution_result = self.create_execution_result(result) |
223 | | - return self.build_statement_result(statement, execution_result) |
| 264 | + if statement._processing_context and statement._processing_context.metadata.get("postgres_copy_operation"): |
| 265 | + try: |
| 266 | + result = await self._handle_copy_operation_from_pipeline(cursor, statement) |
| 267 | + |
| 268 | + # Create ExecutionResult and build SQLResult directly - capture rowcount from cursor |
| 269 | + row_count = result.rowcount if hasattr(result, "rowcount") and result.rowcount is not None else -1 |
| 270 | + execution_result = self.create_execution_result(result, rowcount_override=row_count) |
| 271 | + return self.build_statement_result(statement, execution_result) |
| 272 | + except Exception: |
| 273 | + # Log the error but don't fail silently |
| 274 | + logger.exception("Async COPY operation failed in special handling") |
| 275 | + raise |
224 | 276 |
|
225 | 277 | return None |
226 | 278 |
|
227 | 279 | async def _execute_copy_with_data(self, cursor: Any, sql_text: str, data_str: str) -> Any: |
228 | 280 | """Execute COPY operation with data using Psycopg async context manager.""" |
229 | 281 | async with cursor.copy(sql_text) as copy: |
230 | 282 | await copy.write(data_str) |
| 283 | + # Return cursor after copy operation completes - rowcount should be available |
| 284 | + return cursor |
231 | 285 |
|
232 | 286 | async def _execute_copy_without_data(self, cursor: Any, sql_text: str) -> Any: |
233 | 287 | """Execute COPY operation without data using Psycopg async context manager.""" |
234 | 288 | async with cursor.copy(sql_text): |
235 | 289 | pass # Just execute the COPY command |
| 290 | + return cursor |
236 | 291 |
|
237 | 292 | async def _handle_copy_operation_from_pipeline(self, cursor: Any, statement: "SQL") -> Any: |
238 | 293 | """Async version of COPY handling using the mixin logic.""" |
|
0 commit comments