
Commit d699b11

feat: add Arrow support to PostgreSQL adapters (asyncpg, psycopg, psqlpy) (#157)
Tests & cleanup for Arrow integration
Parent: f17aefe

6 files changed (+699, -47 lines)

sqlspec/adapters/adbc/config.py

Lines changed: 0 additions & 15 deletions

@@ -77,23 +77,12 @@ class AdbcDriverFeatures(TypedDict):
             When True, preserves Arrow extension type metadata when reading data.
             When False, falls back to storage types.
             Default: True
-        enable_arrow_results: Enable native Arrow query results.
-            When True, select_to_arrow() uses cursor.fetch_arrow_table() for
-            zero-copy data transfer (5-10x faster for large datasets).
-            When False, falls back to dict conversion path.
-            Default: True
-        arrow_batch_size: Batch size for Arrow result streaming.
-            Number of rows per batch when streaming Arrow results.
-            Used for future streaming implementation.
-            Default: 1024
     """

     json_serializer: "NotRequired[Callable[[Any], str]]"
     enable_cast_detection: NotRequired[bool]
     strict_type_coercion: NotRequired[bool]
     arrow_extension_types: NotRequired[bool]
-    enable_arrow_results: NotRequired[bool]
-    arrow_batch_size: NotRequired[int]


 __all__ = ("AdbcConfig", "AdbcConnectionParams", "AdbcDriverFeatures")

@@ -158,10 +147,6 @@ def __init__(
             driver_features["strict_type_coercion"] = False
         if "arrow_extension_types" not in driver_features:
             driver_features["arrow_extension_types"] = True
-        if "enable_arrow_results" not in driver_features:
-            driver_features["enable_arrow_results"] = True
-        if "arrow_batch_size" not in driver_features:
-            driver_features["arrow_batch_size"] = 1024

         super().__init__(
             connection_config=self.connection_config,
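
With enable_arrow_results and arrow_batch_size removed, Arrow results no longer need opting in for the ADBC adapter; only the feature keys shown above remain in AdbcDriverFeatures. A minimal configuration sketch under that assumption (the import path mirrors the other adapters and the URI is illustrative; neither is taken from this commit):

# Hypothetical usage sketch; feature keys follow the diff above.
from sqlspec.adapters.adbc import AdbcConfig

config = AdbcConfig(
    connection_config={"uri": "postgresql://localhost:5432/mydb"},
    driver_features={
        "arrow_extension_types": True,  # still defaulted to True in __init__
        "strict_type_coercion": False,  # still defaulted to False in __init__
    },
)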

sqlspec/adapters/bigquery/config.py

Lines changed: 0 additions & 17 deletions

@@ -78,16 +78,6 @@ class BigQueryDriverFeatures(TypedDict):
         enable_uuid_conversion: Enable automatic UUID string conversion.
             When True (default), UUID strings are automatically converted to UUID objects.
             When False, UUID strings are treated as regular strings.
-        enable_arrow_results: Enable native Arrow query results via Storage API.
-            When True (default), select_to_arrow() uses query_job.to_arrow() with
-            Storage API for zero-copy data transfer (5-10x faster for large datasets).
-            Requires google-cloud-bigquery-storage package and API enabled.
-            Falls back to dict conversion if Storage API unavailable.
-            Default: True
-        arrow_batch_size: Batch size for Arrow result streaming.
-            Number of rows per batch when streaming Arrow results.
-            Used for future streaming implementation.
-            Default: 1024
     """

     connection_instance: NotRequired["BigQueryConnection"]

@@ -96,8 +86,6 @@ class BigQueryDriverFeatures(TypedDict):
     on_connection_create: NotRequired["Callable[[Any], None]"]
     json_serializer: NotRequired["Callable[[Any], str]"]
     enable_uuid_conversion: NotRequired[bool]
-    enable_arrow_results: NotRequired[bool]
-    arrow_batch_size: NotRequired[int]


 __all__ = ("BigQueryConfig", "BigQueryConnectionParams", "BigQueryDriverFeatures")

@@ -149,11 +137,6 @@ def __init__(

         self.driver_features["json_serializer"] = to_json

-        if "enable_arrow_results" not in self.driver_features:
-            self.driver_features["enable_arrow_results"] = True
-        if "arrow_batch_size" not in self.driver_features:
-            self.driver_features["arrow_batch_size"] = 1024
-
         self._connection_instance: BigQueryConnection | None = self.driver_features.get("connection_instance")

         if "default_query_job_config" not in self.connection_config:

sqlspec/adapters/duckdb/config.py

Lines changed: 0 additions & 15 deletions

@@ -121,24 +121,13 @@ class DuckDBDriverFeatures(TypedDict):
         enable_uuid_conversion: Enable automatic UUID string conversion.
             When True (default), UUID strings are automatically converted to UUID objects.
             When False, UUID strings are treated as regular strings.
-        enable_arrow_results: Enable native Arrow query results.
-            When True (default), select_to_arrow() uses cursor.arrow() for
-            zero-copy data transfer. DuckDB has the fastest Arrow path due to
-            its columnar architecture.
-            Default: True
-        arrow_batch_size: Batch size for Arrow result streaming.
-            Number of rows per batch when streaming Arrow results.
-            Used for future streaming implementation.
-            Default: 1024
     """

     extensions: NotRequired[Sequence[DuckDBExtensionConfig]]
     secrets: NotRequired[Sequence[DuckDBSecretConfig]]
     on_connection_create: NotRequired["Callable[[DuckDBConnection], DuckDBConnection | None]"]
     json_serializer: NotRequired["Callable[[Any], str]"]
     enable_uuid_conversion: NotRequired[bool]
-    enable_arrow_results: NotRequired[bool]
-    arrow_batch_size: NotRequired[int]


 class DuckDBConfig(SyncDatabaseConfig[DuckDBConnection, DuckDBConnectionPool, DuckDBDriver]):

@@ -223,10 +212,6 @@ def __init__(
         processed_features = dict(driver_features) if driver_features else {}
         if "enable_uuid_conversion" not in processed_features:
             processed_features["enable_uuid_conversion"] = True
-        if "enable_arrow_results" not in processed_features:
-            processed_features["enable_arrow_results"] = True
-        if "arrow_batch_size" not in processed_features:
-            processed_features["arrow_batch_size"] = 1024

         super().__init__(
             bind_key=bind_key,
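
The removed DuckDB docstring referred to a cursor.arrow() fast path. For reference, this is what that native path looks like with the duckdb package used directly (a standalone sketch, not code from this commit):

# Standalone sketch of DuckDB's native Arrow path using the duckdb package.
import duckdb

con = duckdb.connect()  # in-memory database
con.execute("CREATE TABLE t AS SELECT range AS id, range * 10 AS value FROM range(5)")

# .arrow() materializes the result as a pyarrow.Table without a
# row-by-row Python conversion step.
table = con.execute("SELECT * FROM t ORDER BY id").arrow()
print(table.num_rows, table.column_names)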
Lines changed: 234 additions & 0 deletions

@@ -0,0 +1,234 @@ (new file)

"""Integration tests for asyncpg Arrow support."""

import pytest
from pytest_databases.docker.postgres import PostgresService

from sqlspec._typing import PYARROW_INSTALLED
from sqlspec.adapters.asyncpg import AsyncpgConfig

pytestmark = [
    pytest.mark.xdist_group("postgres"),
    pytest.mark.skipif(not PYARROW_INSTALLED, reason="pyarrow not installed"),
]


@pytest.fixture
async def asyncpg_config(postgres_service: PostgresService) -> AsyncpgConfig:
    """Create AsyncPG configuration for testing."""
    return AsyncpgConfig(
        pool_config={
            "dsn": f"postgres://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}",
            "min_size": 1,
            "max_size": 2,
        }
    )


async def test_select_to_arrow_basic(asyncpg_config: AsyncpgConfig) -> None:
    """Test basic select_to_arrow functionality."""
    import pyarrow as pa

    try:
        async with asyncpg_config.provide_session() as session:
            # Create test table with unique name
            await session.execute("DROP TABLE IF EXISTS arrow_users CASCADE")
            await session.execute("CREATE TABLE arrow_users (id INTEGER, name TEXT, age INTEGER)")
            await session.execute("INSERT INTO arrow_users VALUES (1, 'Alice', 30), (2, 'Bob', 25)")

            # Test Arrow query
            result = await session.select_to_arrow("SELECT * FROM arrow_users ORDER BY id")

            assert result is not None
            assert isinstance(result.data, (pa.Table, pa.RecordBatch))
            assert result.rows_affected == 2

            # Convert to pandas and verify
            df = result.to_pandas()
            assert len(df) == 2
            assert list(df["name"]) == ["Alice", "Bob"]
            assert list(df["age"]) == [30, 25]
    finally:
        await asyncpg_config.close_pool()


async def test_select_to_arrow_table_format(asyncpg_config: AsyncpgConfig) -> None:
    """Test select_to_arrow with table return format (default)."""
    import pyarrow as pa

    try:
        async with asyncpg_config.provide_session() as session:
            await session.execute("DROP TABLE IF EXISTS arrow_table_test CASCADE")
            await session.execute("CREATE TABLE arrow_table_test (id INTEGER, value TEXT)")
            await session.execute("INSERT INTO arrow_table_test VALUES (1, 'a'), (2, 'b'), (3, 'c')")

            result = await session.select_to_arrow("SELECT * FROM arrow_table_test ORDER BY id", return_format="table")

            assert isinstance(result.data, pa.Table)
            assert result.rows_affected == 3
    finally:
        await asyncpg_config.close_pool()


async def test_select_to_arrow_batch_format(asyncpg_config: AsyncpgConfig) -> None:
    """Test select_to_arrow with batch return format."""
    import pyarrow as pa

    try:
        async with asyncpg_config.provide_session() as session:
            await session.execute("DROP TABLE IF EXISTS arrow_batch_test CASCADE")
            await session.execute("CREATE TABLE arrow_batch_test (id INTEGER, value TEXT)")
            await session.execute("INSERT INTO arrow_batch_test VALUES (1, 'a'), (2, 'b')")

            result = await session.select_to_arrow(
                "SELECT * FROM arrow_batch_test ORDER BY id", return_format="batches"
            )

            assert isinstance(result.data, pa.RecordBatch)
            assert result.rows_affected == 2
    finally:
        await asyncpg_config.close_pool()


async def test_select_to_arrow_with_parameters(asyncpg_config: AsyncpgConfig) -> None:
    """Test select_to_arrow with query parameters."""
    try:
        async with asyncpg_config.provide_session() as session:
            await session.execute("DROP TABLE IF EXISTS arrow_params_test CASCADE")
            await session.execute("CREATE TABLE arrow_params_test (id INTEGER, value INTEGER)")
            await session.execute("INSERT INTO arrow_params_test VALUES (1, 100), (2, 200), (3, 300)")

            # Test with parameterized query
            result = await session.select_to_arrow("SELECT * FROM arrow_params_test WHERE value > $1 ORDER BY id", 150)

            assert result.rows_affected == 2
            df = result.to_pandas()
            assert list(df["value"]) == [200, 300]
    finally:
        await asyncpg_config.close_pool()


async def test_select_to_arrow_empty_result(asyncpg_config: AsyncpgConfig) -> None:
    """Test select_to_arrow with empty result set."""
    try:
        async with asyncpg_config.provide_session() as session:
            await session.execute("DROP TABLE IF EXISTS arrow_empty_test CASCADE")
            await session.execute("CREATE TABLE arrow_empty_test (id INTEGER)")

            result = await session.select_to_arrow("SELECT * FROM arrow_empty_test")

            assert result.rows_affected == 0
            assert len(result.to_pandas()) == 0
    finally:
        await asyncpg_config.close_pool()


async def test_select_to_arrow_null_handling(asyncpg_config: AsyncpgConfig) -> None:
    """Test select_to_arrow with NULL values."""
    try:
        async with asyncpg_config.provide_session() as session:
            await session.execute("DROP TABLE IF EXISTS arrow_null_test CASCADE")
            await session.execute("CREATE TABLE arrow_null_test (id INTEGER, value TEXT)")
            await session.execute("INSERT INTO arrow_null_test VALUES (1, 'a'), (2, NULL), (3, 'c')")

            result = await session.select_to_arrow("SELECT * FROM arrow_null_test ORDER BY id")

            df = result.to_pandas()
            assert len(df) == 3
            assert df.iloc[1]["value"] is None or df.isna().iloc[1]["value"]
    finally:
        await asyncpg_config.close_pool()


async def test_select_to_arrow_to_polars(asyncpg_config: AsyncpgConfig) -> None:
    """Test select_to_arrow conversion to Polars DataFrame."""
    pytest.importorskip("polars")

    try:
        async with asyncpg_config.provide_session() as session:
            await session.execute("DROP TABLE IF EXISTS arrow_polars_test CASCADE")
            await session.execute("CREATE TABLE arrow_polars_test (id INTEGER, value TEXT)")
            await session.execute("INSERT INTO arrow_polars_test VALUES (1, 'a'), (2, 'b')")

            result = await session.select_to_arrow("SELECT * FROM arrow_polars_test ORDER BY id")
            df = result.to_polars()

            assert len(df) == 2
            assert df["value"].to_list() == ["a", "b"]
    finally:
        await asyncpg_config.close_pool()


async def test_select_to_arrow_large_dataset(asyncpg_config: AsyncpgConfig) -> None:
    """Test select_to_arrow with larger dataset."""
    try:
        async with asyncpg_config.provide_session() as session:
            await session.execute("DROP TABLE IF EXISTS arrow_large_test CASCADE")
            await session.execute("CREATE TABLE arrow_large_test (id INTEGER, value INTEGER)")

            # Insert 1000 rows
            values = ", ".join(f"({i}, {i * 10})" for i in range(1, 1001))
            await session.execute(f"INSERT INTO arrow_large_test VALUES {values}")

            result = await session.select_to_arrow("SELECT * FROM arrow_large_test ORDER BY id")

            assert result.rows_affected == 1000
            df = result.to_pandas()
            assert len(df) == 1000
            assert df["value"].sum() == sum(i * 10 for i in range(1, 1001))
    finally:
        await asyncpg_config.close_pool()


async def test_select_to_arrow_type_preservation(asyncpg_config: AsyncpgConfig) -> None:
    """Test that PostgreSQL types are properly converted to Arrow types."""
    try:
        async with asyncpg_config.provide_session() as session:
            await session.execute("DROP TABLE IF EXISTS arrow_types_test CASCADE")
            await session.execute(
                """
                CREATE TABLE arrow_types_test (
                    id INTEGER,
                    name TEXT,
                    price NUMERIC,
                    created_at TIMESTAMP,
                    is_active BOOLEAN
                )
                """
            )
            await session.execute(
                """
                INSERT INTO arrow_types_test VALUES
                (1, 'Item 1', 19.99, '2025-01-01 10:00:00', true),
                (2, 'Item 2', 29.99, '2025-01-02 15:30:00', false)
                """
            )

            result = await session.select_to_arrow("SELECT * FROM arrow_types_test ORDER BY id")

            df = result.to_pandas()
            assert len(df) == 2
            assert df["name"].dtype == object
            assert df["is_active"].dtype == bool
    finally:
        await asyncpg_config.close_pool()


async def test_select_to_arrow_postgres_array(asyncpg_config: AsyncpgConfig) -> None:
    """Test PostgreSQL array type handling in Arrow results."""
    try:
        async with asyncpg_config.provide_session() as session:
            await session.execute("DROP TABLE IF EXISTS arrow_array_test CASCADE")
            await session.execute("CREATE TABLE arrow_array_test (id INTEGER, tags TEXT[])")
            await session.execute(
                "INSERT INTO arrow_array_test VALUES (1, ARRAY['python', 'rust']), (2, ARRAY['js', 'ts'])"
            )

            result = await session.select_to_arrow("SELECT * FROM arrow_array_test ORDER BY id")

            # PostgreSQL arrays are returned as Python lists in dict format,
            # which Arrow converts to list type
            df = result.to_pandas()
            assert len(df) == 2
            assert isinstance(df["tags"].iloc[0], (list, object))
    finally:
        await asyncpg_config.close_pool()
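
Outside a test suite, the pattern the tests exercise boils down to a few lines. A usage sketch distilled from the fixtures and assertions above (the DSN is illustrative):

# Usage sketch; API calls mirror the tests above, the DSN is made up.
import asyncio

from sqlspec.adapters.asyncpg import AsyncpgConfig


async def main() -> None:
    config = AsyncpgConfig(pool_config={"dsn": "postgres://user:secret@localhost:5432/appdb"})
    try:
        async with config.provide_session() as session:
            result = await session.select_to_arrow("SELECT 1 AS id, 'Alice' AS name")
            table = result.data  # a pyarrow.Table with the default return_format
            print(table.num_rows, table.column_names)
    finally:
        await config.close_pool()


asyncio.run(main())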
