Skip to content

Commit e3701ce

Browse files
authored
feat: implements a select_arrow bulk query method (#22)
Implement a `select_arrow` function that allows you to select directly into an `Arrow` format. Currently implemented for ADBC (Postgres, SQLite, Snowflake, Bigquery, Duckdb, and Postgres), DuckDB, and Oracle (Async and Sync)
1 parent b045798 commit e3701ce

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

72 files changed

+1271
-595
lines changed

Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ clean: ## Cleanup temporary build a
118118
.PHONY: test
119119
test: ## Run the tests
120120
@echo "${INFO} Running test cases... 🧪"
121-
@uv run pytest tests
121+
@uv run pytest -n 2 --dist=loadgroup tests
122122
@echo "${OK} Tests complete ✨"
123123

124124
.PHONY: test-all
@@ -128,7 +128,7 @@ test-all: tests ## Run all tests
128128
.PHONY: coverage
129129
coverage: ## Run tests with coverage report
130130
@echo "${INFO} Running tests with coverage... 📊"
131-
@uv run pytest --cov -n auto --quiet
131+
@uv run pytest --cov -n 2 --dist=loadgroup --quiet
132132
@uv run coverage html >/dev/null 2>&1
133133
@uv run coverage xml >/dev/null 2>&1
134134
@echo "${OK} Coverage report generated ✨"

pyproject.toml

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ exclude_lines = [
175175
]
176176

177177
[tool.pytest.ini_options]
178-
addopts = "-ra -q --doctest-glob='*.md' --strict-markers --strict-config"
178+
addopts = ["-q", "-ra"]
179179
asyncio_default_fixture_loop_scope = "function"
180180
asyncio_mode = "auto"
181181
filterwarnings = [
@@ -189,8 +189,31 @@ filterwarnings = [
189189
"ignore::DeprecationWarning:websockets.connection",
190190
"ignore::DeprecationWarning:websockets.legacy",
191191
]
192+
markers = [
193+
"integration: marks tests that require an external database",
194+
"postgres: marks tests specific to PostgreSQL",
195+
"duckdb: marks tests specific to DuckDB",
196+
"sqlite: marks tests specific to SQLite",
197+
"bigquery: marks tests specific to Google BigQuery",
198+
"mysql: marks tests specific to MySQL",
199+
"oracle: marks tests specific to Oracle",
200+
"spanner: marks tests specific to Google Cloud Spanner",
201+
"mssql: marks tests specific to Microsoft SQL Server",
202+
# Driver markers
203+
"adbc: marks tests using ADBC drivers",
204+
"aioodbc: marks tests using aioodbc",
205+
"aiosqlite: marks tests using aiosqlite",
206+
"asyncmy: marks tests using asyncmy",
207+
"asyncpg: marks tests using asyncpg",
208+
"duckdb_driver: marks tests using the duckdb driver",
209+
"google_bigquery: marks tests using google-cloud-bigquery",
210+
"google_spanner: marks tests using google-cloud-spanner",
211+
"oracledb: marks tests using oracledb",
212+
"psycopg: marks tests using psycopg",
213+
"pymssql: marks tests using pymssql",
214+
"pymysql: marks tests using pymysql",
215+
]
192216
testpaths = ["tests"]
193-
xfail_strict = true
194217

195218
[tool.mypy]
196219
packages = ["sqlspec", "tests"]
@@ -220,6 +243,8 @@ module = [
220243
"uvloop.*",
221244
"asyncmy",
222245
"asyncmy.*",
246+
"pyarrow",
247+
"pyarrow.*",
223248
]
224249

225250
[tool.pyright]

sqlspec/_typing.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
# ruff: noqa: RUF100, PLR0913, A002, DOC201, PLR6301
12
"""This is a simple wrapper around a few important classes in each library.
23
34
This is used to ensure compatibility when one or more of the libraries are installed.
45
"""
56

7+
from collections.abc import Iterable, Mapping
68
from enum import Enum
79
from typing import (
810
Any,
@@ -96,7 +98,7 @@ def __init__(
9698

9799
def validate_python(
98100
self,
99-
object: Any, # noqa: A002
101+
object: Any,
100102
/,
101103
*,
102104
strict: "Optional[bool]" = None,
@@ -127,10 +129,7 @@ class FailFast: # type: ignore[no-redef]
127129
except ImportError:
128130
import enum
129131
from collections.abc import Iterable
130-
from typing import TYPE_CHECKING, Callable, Optional, Union
131-
132-
if TYPE_CHECKING:
133-
from collections.abc import Iterable
132+
from typing import Callable, Optional, Union
134133

135134
@dataclass_transform()
136135
@runtime_checkable
@@ -174,7 +173,6 @@ def __init__(self, backend: Any, data_as_builtins: Any) -> None:
174173
"""Placeholder init"""
175174

176175
def create_instance(self, **kwargs: Any) -> "T":
177-
"""Placeholder implementation"""
178176
return cast("T", kwargs)
179177

180178
def update_instance(self, instance: "T", **kwargs: Any) -> "T":
@@ -198,11 +196,46 @@ class EmptyEnum(Enum):
198196
Empty: Final = EmptyEnum.EMPTY
199197

200198

199+
try:
200+
from pyarrow import Table as ArrowTable
201+
202+
PYARROW_INSTALLED = True
203+
except ImportError:
204+
205+
@runtime_checkable
206+
class ArrowTable(Protocol): # type: ignore[no-redef]
207+
"""Placeholder Implementation"""
208+
209+
def to_batches(self, batch_size: int) -> Any: ...
210+
def num_rows(self) -> int: ...
211+
def num_columns(self) -> int: ...
212+
def to_pydict(self) -> dict[str, Any]: ...
213+
def to_string(self) -> str: ...
214+
def from_arrays(
215+
self,
216+
arrays: list[Any],
217+
names: "Optional[list[str]]" = None,
218+
schema: "Optional[Any]" = None,
219+
metadata: "Optional[Mapping[str, Any]]" = None,
220+
) -> Any: ...
221+
def from_pydict(
222+
self,
223+
mapping: dict[str, Any],
224+
schema: "Optional[Any]" = None,
225+
metadata: "Optional[Mapping[str, Any]]" = None,
226+
) -> Any: ...
227+
def from_batches(self, batches: Iterable[Any], schema: Optional[Any] = None) -> Any: ...
228+
229+
PYARROW_INSTALLED = False # pyright: ignore[reportConstantRedefinition]
230+
231+
201232
__all__ = (
202233
"LITESTAR_INSTALLED",
203234
"MSGSPEC_INSTALLED",
235+
"PYARROW_INSTALLED",
204236
"PYDANTIC_INSTALLED",
205237
"UNSET",
238+
"ArrowTable",
206239
"BaseModel",
207240
"DTOData",
208241
"DataclassProtocol",

sqlspec/adapters/adbc/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from sqlspec.adapters.adbc.config import Adbc
1+
from sqlspec.adapters.adbc.config import AdbcConfig
22
from sqlspec.adapters.adbc.driver import AdbcDriver
33

44
__all__ = (
5-
"Adbc",
5+
"AdbcConfig",
66
"AdbcDriver",
77
)

sqlspec/adapters/adbc/config.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@
1414
from collections.abc import Generator
1515

1616

17-
__all__ = ("Adbc",)
17+
__all__ = ("AdbcConfig",)
1818

1919

2020
@dataclass
21-
class Adbc(NoPoolSyncConfig["Connection", "AdbcDriver"]):
21+
class AdbcConfig(NoPoolSyncConfig["Connection", "AdbcDriver"]):
2222
"""Configuration for ADBC connections.
2323
2424
This class provides configuration options for ADBC database connections using the
@@ -55,17 +55,41 @@ def _set_adbc(self) -> str: # noqa: PLR0912
5555
"""
5656

5757
if isinstance(self.driver_name, str):
58-
if self.driver_name != "adbc_driver_sqlite.dbapi.connect" and "sqlite" in self.driver_name:
58+
if self.driver_name != "adbc_driver_sqlite.dbapi.connect" and self.driver_name in {
59+
"sqlite",
60+
"sqlite3",
61+
"adbc_driver_sqlite",
62+
}:
5963
self.driver_name = "adbc_driver_sqlite.dbapi.connect"
60-
elif self.driver_name != "adbc_driver_duckdb.dbapi.connect" and "duckdb" in self.driver_name:
64+
elif self.driver_name != "adbc_driver_duckdb.dbapi.connect" and self.driver_name in {
65+
"duckdb",
66+
"adbc_driver_duckdb",
67+
}:
6168
self.driver_name = "adbc_driver_duckdb.dbapi.connect"
62-
elif self.driver_name != "adbc_driver_postgresql.dbapi.connect" and "postgres" in self.driver_name:
69+
elif self.driver_name != "adbc_driver_postgresql.dbapi.connect" and self.driver_name in {
70+
"postgres",
71+
"adbc_driver_postgresql",
72+
"postgresql",
73+
"pg",
74+
}:
6375
self.driver_name = "adbc_driver_postgresql.dbapi.connect"
64-
elif self.driver_name != "adbc_driver_snowflake.dbapi.connect" and "snowflake" in self.driver_name:
76+
elif self.driver_name != "adbc_driver_snowflake.dbapi.connect" and self.driver_name in {
77+
"snowflake",
78+
"adbc_driver_snowflake",
79+
"sf",
80+
}:
6581
self.driver_name = "adbc_driver_snowflake.dbapi.connect"
66-
elif self.driver_name != "adbc_driver_bigquery.dbapi.connect" and "bigquery" in self.driver_name:
82+
elif self.driver_name != "adbc_driver_bigquery.dbapi.connect" and self.driver_name in {
83+
"bigquery",
84+
"adbc_driver_bigquery",
85+
"bq",
86+
}:
6787
self.driver_name = "adbc_driver_bigquery.dbapi.connect"
68-
elif self.driver_name != "adbc_driver_flightsql.dbapi.connect" and "flightsql" in self.driver_name:
88+
elif self.driver_name != "adbc_driver_flightsql.dbapi.connect" and self.driver_name in {
89+
"flightsql",
90+
"adbc_driver_flightsql",
91+
"grpc",
92+
}:
6993
self.driver_name = "adbc_driver_flightsql.dbapi.connect"
7094
return self.driver_name
7195

@@ -153,11 +177,10 @@ def create_connection(self) -> "Connection":
153177
"""
154178
try:
155179
connect_func = self._get_connect_func()
156-
_config = self.connection_config_dict
157-
return connect_func(**_config)
180+
return connect_func(**self.connection_config_dict)
158181
except Exception as e:
159182
# Include driver name in error message for better context
160-
driver_name = self.driver_name if isinstance(self.driver_name, str) else "Unknown/Derived"
183+
driver_name = self.driver_name if isinstance(self.driver_name, str) else "Unknown/Missing"
161184
# Use the potentially modified driver_path from _get_connect_func if available,
162185
# otherwise fallback to self.driver_name for the error message.
163186
# This requires _get_connect_func to potentially return the used path or store it.

sqlspec/adapters/adbc/driver.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,16 @@
22
import re
33
from collections.abc import Generator
44
from contextlib import contextmanager
5-
from typing import TYPE_CHECKING, Any, Optional, Union, cast
5+
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, cast
66

7-
from adbc_driver_manager.dbapi import Connection, Cursor
7+
from adbc_driver_manager.dbapi import Connection
8+
from adbc_driver_manager.dbapi import Cursor as DbapiCursor
89

9-
from sqlspec.base import SyncDriverAdapterProtocol, T
10+
from sqlspec._typing import ArrowTable
11+
from sqlspec.base import SyncArrowBulkOperationsMixin, SyncDriverAdapterProtocol, T
1012

1113
if TYPE_CHECKING:
12-
from sqlspec.typing import ModelDTOT, StatementParameterType
14+
from sqlspec.typing import ArrowTable, ModelDTOT, StatementParameterType
1315

1416
__all__ = ("AdbcDriver",)
1517

@@ -26,10 +28,11 @@
2628
)
2729

2830

29-
class AdbcDriver(SyncDriverAdapterProtocol["Connection"]):
31+
class AdbcDriver(SyncArrowBulkOperationsMixin["Connection"], SyncDriverAdapterProtocol["Connection"]):
3032
"""ADBC Sync Driver Adapter."""
3133

3234
connection: Connection
35+
__supports_arrow__: ClassVar[bool] = True
3336

3437
def __init__(self, connection: "Connection") -> None:
3538
"""Initialize the ADBC driver adapter."""
@@ -38,12 +41,12 @@ def __init__(self, connection: "Connection") -> None:
3841
# For now, assume 'qmark' based on typical ADBC DBAPI behavior
3942

4043
@staticmethod
41-
def _cursor(connection: "Connection", *args: Any, **kwargs: Any) -> "Cursor":
44+
def _cursor(connection: "Connection", *args: Any, **kwargs: Any) -> "DbapiCursor":
4245
return connection.cursor(*args, **kwargs)
4346

4447
@contextmanager
45-
def _with_cursor(self, connection: "Connection") -> Generator["Cursor", None, None]:
46-
cursor = self._cursor(connection)
48+
def _with_cursor(self, connection: "Connection") -> Generator["DbapiCursor", None, None]:
49+
cursor: DbapiCursor = self._cursor(connection)
4750
try:
4851
yield cursor
4952
finally:
@@ -331,3 +334,24 @@ def execute_script_returning(
331334
if schema_type is not None:
332335
return cast("ModelDTOT", schema_type(**dict(zip(column_names, result[0])))) # pyright: ignore[reportUnknownArgumentType]
333336
return dict(zip(column_names, result[0])) # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType]
337+
338+
# --- Arrow Bulk Operations ---
339+
340+
def select_arrow( # pyright: ignore[reportUnknownParameterType]
341+
self,
342+
sql: str,
343+
parameters: "Optional[StatementParameterType]" = None,
344+
/,
345+
connection: "Optional[Connection]" = None,
346+
) -> "ArrowTable":
347+
"""Execute a SQL query and return results as an Apache Arrow Table.
348+
349+
Returns:
350+
The results of the query as an Apache Arrow Table.
351+
"""
352+
conn = self._connection(connection)
353+
sql, parameters = self._process_sql_params(sql, parameters)
354+
355+
with self._with_cursor(conn) as cursor:
356+
cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType]
357+
return cast("ArrowTable", cursor.fetch_arrow_table()) # pyright: ignore[reportUnknownMemberType]
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from sqlspec.adapters.aiosqlite.config import Aiosqlite
1+
from sqlspec.adapters.aiosqlite.config import AiosqliteConfig
22
from sqlspec.adapters.aiosqlite.driver import AiosqliteDriver
33

44
__all__ = (
5-
"Aiosqlite",
5+
"AiosqliteConfig",
66
"AiosqliteDriver",
77
)

sqlspec/adapters/aiosqlite/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515
from typing import Literal
1616

1717

18-
__all__ = ("Aiosqlite",)
18+
__all__ = ("AiosqliteConfig",)
1919

2020

2121
@dataclass
22-
class Aiosqlite(NoPoolAsyncConfig["Connection", "AiosqliteDriver"]):
22+
class AiosqliteConfig(NoPoolAsyncConfig["Connection", "AiosqliteDriver"]):
2323
"""Configuration for Aiosqlite database connections.
2424
2525
This class provides configuration options for Aiosqlite database connections, wrapping all parameters
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
from sqlspec.adapters.asyncmy.config import Asyncmy, AsyncmyPool
1+
from sqlspec.adapters.asyncmy.config import AsyncmyConfig, AsyncmyPoolConfig
22
from sqlspec.adapters.asyncmy.driver import AsyncmyDriver # type: ignore[attr-defined]
33

44
__all__ = (
5-
"Asyncmy",
5+
"AsyncmyConfig",
66
"AsyncmyDriver",
7-
"AsyncmyPool",
7+
"AsyncmyPoolConfig",
88
)

sqlspec/adapters/asyncmy/config.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,16 @@
1616
from asyncmy.pool import Pool # pyright: ignore[reportUnknownVariableType]
1717

1818
__all__ = (
19-
"Asyncmy",
20-
"AsyncmyPool",
19+
"AsyncmyConfig",
20+
"AsyncmyPoolConfig",
2121
)
2222

2323

2424
T = TypeVar("T")
2525

2626

2727
@dataclass
28-
class AsyncmyPool(GenericPoolConfig):
28+
class AsyncmyPoolConfig(GenericPoolConfig):
2929
"""Configuration for Asyncmy's connection pool.
3030
3131
This class provides configuration options for Asyncmy database connection pools.
@@ -104,19 +104,19 @@ def pool_config_dict(self) -> "dict[str, Any]":
104104

105105

106106
@dataclass
107-
class Asyncmy(AsyncDatabaseConfig["Connection", "Pool", "AsyncmyDriver"]):
107+
class AsyncmyConfig(AsyncDatabaseConfig["Connection", "Pool", "AsyncmyDriver"]):
108108
"""Asyncmy Configuration."""
109109

110110
__is_async__ = True
111111
__supports_connection_pooling__ = True
112112

113-
pool_config: "Optional[AsyncmyPool]" = None
113+
pool_config: "Optional[AsyncmyPoolConfig]" = None
114114
"""Asyncmy Pool configuration"""
115-
connection_type: "type[Connection]" = field(init=False, default_factory=lambda: Connection) # pyright: ignore
115+
connection_type: "type[Connection]" = field(hash=False, init=False, default_factory=lambda: Connection) # pyright: ignore
116116
"""Type of the connection object"""
117-
driver_type: "type[AsyncmyDriver]" = field(init=False, default_factory=lambda: AsyncmyDriver)
117+
driver_type: "type[AsyncmyDriver]" = field(hash=False, init=False, default_factory=lambda: AsyncmyDriver)
118118
"""Type of the driver object"""
119-
pool_instance: "Optional[Pool]" = None # pyright: ignore[reportUnknownVariableType]
119+
pool_instance: "Optional[Pool]" = field(hash=False, default=None) # pyright: ignore[reportUnknownVariableType]
120120
"""Instance of the pool"""
121121

122122
@property

0 commit comments

Comments
 (0)