22import re
33from collections .abc import Generator
44from contextlib import contextmanager
5- from typing import TYPE_CHECKING , Any , Optional , Union , cast
5+ from typing import TYPE_CHECKING , Any , ClassVar , Optional , Union , cast
66
7- from adbc_driver_manager .dbapi import Connection , Cursor
7+ from adbc_driver_manager .dbapi import Connection
8+ from adbc_driver_manager .dbapi import Cursor as DbapiCursor
89
9- from sqlspec .base import SyncDriverAdapterProtocol , T
10+ from sqlspec ._typing import ArrowTable
11+ from sqlspec .base import SyncArrowBulkOperationsMixin , SyncDriverAdapterProtocol , T
1012
1113if TYPE_CHECKING :
12- from sqlspec .typing import ModelDTOT , StatementParameterType
14+ from sqlspec .typing import ArrowTable , ModelDTOT , StatementParameterType
1315
1416__all__ = ("AdbcDriver" ,)
1517
2628)
2729
2830
29- class AdbcDriver (SyncDriverAdapterProtocol ["Connection" ]):
31+ class AdbcDriver (SyncArrowBulkOperationsMixin [ "Connection" ], SyncDriverAdapterProtocol ["Connection" ]):
3032 """ADBC Sync Driver Adapter."""
3133
3234 connection : Connection
35+ __supports_arrow__ : ClassVar [bool ] = True
3336
3437 def __init__ (self , connection : "Connection" ) -> None :
3538 """Initialize the ADBC driver adapter."""
@@ -38,12 +41,12 @@ def __init__(self, connection: "Connection") -> None:
3841 # For now, assume 'qmark' based on typical ADBC DBAPI behavior
3942
4043 @staticmethod
41- def _cursor (connection : "Connection" , * args : Any , ** kwargs : Any ) -> "Cursor " :
44+ def _cursor (connection : "Connection" , * args : Any , ** kwargs : Any ) -> "DbapiCursor " :
4245 return connection .cursor (* args , ** kwargs )
4346
4447 @contextmanager
45- def _with_cursor (self , connection : "Connection" ) -> Generator ["Cursor " , None , None ]:
46- cursor = self ._cursor (connection )
48+ def _with_cursor (self , connection : "Connection" ) -> Generator ["DbapiCursor " , None , None ]:
49+ cursor : DbapiCursor = self ._cursor (connection )
4750 try :
4851 yield cursor
4952 finally :
@@ -331,3 +334,24 @@ def execute_script_returning(
331334 if schema_type is not None :
332335 return cast ("ModelDTOT" , schema_type (** dict (zip (column_names , result [0 ])))) # pyright: ignore[reportUnknownArgumentType]
333336 return dict (zip (column_names , result [0 ])) # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType]
337+
338+ # --- Arrow Bulk Operations ---
339+
340+ def select_arrow ( # pyright: ignore[reportUnknownParameterType]
341+ self ,
342+ sql : str ,
343+ parameters : "Optional[StatementParameterType]" = None ,
344+ / ,
345+ connection : "Optional[Connection]" = None ,
346+ ) -> "ArrowTable" :
347+ """Execute a SQL query and return results as an Apache Arrow Table.
348+
349+ Returns:
350+ The results of the query as an Apache Arrow Table.
351+ """
352+ conn = self ._connection (connection )
353+ sql , parameters = self ._process_sql_params (sql , parameters )
354+
355+ with self ._with_cursor (conn ) as cursor :
356+ cursor .execute (sql , parameters ) # pyright: ignore[reportUnknownMemberType]
357+ return cast ("ArrowTable" , cursor .fetch_arrow_table ()) # pyright: ignore[reportUnknownMemberType]
0 commit comments