Skip to content

Commit 37cda80

Browse files
test: Ensure proper async connection cleanup on DB test exit (pola-rs#25766)
1 parent ac0b751 commit 37cda80

File tree

2 files changed

+69
-42
lines changed

2 files changed

+69
-42
lines changed

py-polars/src/polars/io/database/_executor.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,16 @@ async def _sqlalchemy_async_execute(self, query: TextClause, **options: Any) ->
456456
"""Execute a query using an async SQLAlchemy connection."""
457457
is_session = self._is_alchemy_session(self.cursor)
458458
cursor = self.cursor.begin() if is_session else self.cursor # type: ignore[attr-defined]
459+
460+
# check if connection is already started (eg: user awaited `engine.connect()`);
461+
# if so, use it directly without entering the context manager again
462+
try:
463+
if object.__getattribute__(cursor, "sync_connection") is not None:
464+
result = await cursor.execute(query, **options)
465+
return result
466+
except AttributeError:
467+
pass
468+
459469
async with cursor as conn: # type: ignore[union-attr]
460470
if is_session and not hasattr(conn, "execute"):
461471
conn = conn.session

py-polars/tests/unit/io/database/test_async.py

Lines changed: 59 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from collections.abc import Iterable
2020
from pathlib import Path
2121

22+
2223
SURREAL_MOCK_DATA: list[dict[str, Any]] = [
2324
{
2425
"id": "item:8xj31jfpdkf9gvmxdxpi",
@@ -84,58 +85,74 @@ class MockedSurrealModule(ModuleType):
8485
)
8586
def test_read_async(tmp_sqlite_db: Path) -> None:
8687
# confirm that we can load frame data from the core sqlalchemy async
87-
# primitives: AsyncConnection, AsyncEngine, and async_sessionmaker
88+
# primitives: AsyncEngine, AsyncConnection, async_sessionmaker, and AsyncSession
8889
from sqlalchemy.ext.asyncio import async_sessionmaker
8990

90-
async_engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_sqlite_db}")
91-
async_connection = async_engine.connect()
92-
async_session = async_sessionmaker(async_engine)
93-
async_session_inst = async_session()
94-
95-
expected_frame = pl.DataFrame(
96-
{"id": [2, 1], "name": ["other", "misc"], "value": [-99.5, 100.0]}
97-
)
98-
async_conn: Any
99-
for async_conn in (
100-
async_engine,
101-
async_connection,
102-
async_session,
103-
async_session_inst,
104-
):
105-
if async_conn in (async_session, async_session_inst):
106-
constraint, execute_opts = "", {}
107-
else:
108-
constraint = "WHERE value > :n"
109-
execute_opts = {"parameters": {"n": -1000}}
110-
111-
df = pl.read_database(
112-
query=f"""
113-
SELECT id, name, value
114-
FROM test_data {constraint}
115-
ORDER BY id DESC
116-
""",
117-
connection=async_conn,
118-
execute_options=execute_opts,
119-
)
120-
assert_frame_equal(expected_frame, df)
91+
async def _test_impl() -> None:
92+
async_engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_sqlite_db}")
93+
async_connection = await async_engine.connect()
94+
try:
95+
async_session = async_sessionmaker(async_engine)
96+
async_session_inst = async_session()
12197

98+
expected_frame = pl.DataFrame(
99+
{"id": [2, 1], "name": ["other", "misc"], "value": [-99.5, 100.0]}
100+
)
101+
async_conn: Any
102+
for async_conn in (
103+
async_engine,
104+
async_connection,
105+
async_session,
106+
async_session_inst,
107+
):
108+
if async_conn in (async_session, async_session_inst):
109+
constraint, execute_opts = "", {}
110+
else:
111+
constraint = "WHERE value > :n"
112+
execute_opts = {"parameters": {"n": -1000}}
113+
114+
df = pl.read_database(
115+
query=f"""
116+
SELECT id, name, value
117+
FROM test_data {constraint}
118+
ORDER BY id DESC
119+
""",
120+
connection=async_conn,
121+
execute_options=execute_opts,
122+
)
123+
assert_frame_equal(expected_frame, df)
124+
finally:
125+
await async_session_inst.close()
126+
await async_connection.close()
127+
await async_engine.dispose()
122128

123-
async def _nested_async_test(tmp_sqlite_db: Path) -> pl.DataFrame:
124-
async_engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_sqlite_db}")
125-
return pl.read_database(
126-
query="SELECT id, name FROM test_data ORDER BY id",
127-
connection=async_engine.connect(),
128-
)
129+
asyncio.run(_test_impl())
129130

130131

131132
@pytest.mark.skipif(
132133
parse_version(sqlalchemy.__version__) < (2, 0),
133134
reason="SQLAlchemy 2.0+ required for async tests",
134135
)
135-
def test_read_async_nested(tmp_sqlite_db: Path) -> None:
136-
# This tests validates that we can handle nested async calls
136+
@pytest.mark.parametrize("started", [True, False])
137+
def test_read_async_nested(tmp_sqlite_db: Path, started: bool) -> None:
138+
# validate that we can handle nested async calls; check
139+
# this works with connections that are started/unstarted
140+
async def _test_impl() -> pl.DataFrame:
141+
async_engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_sqlite_db}")
142+
async_connection = async_engine.connect()
143+
if started:
144+
async_connection = await async_connection
145+
try:
146+
return pl.read_database(
147+
query="SELECT id, name FROM test_data ORDER BY id",
148+
connection=async_connection,
149+
)
150+
finally:
151+
await async_connection.close()
152+
await async_engine.dispose()
153+
137154
expected_frame = pl.DataFrame({"id": [1, 2], "name": ["misc", "other"]})
138-
df = asyncio.run(_nested_async_test(tmp_sqlite_db))
155+
df = asyncio.run(_test_impl())
139156
assert_frame_equal(expected_frame, df)
140157

141158

@@ -187,7 +204,7 @@ def test_surrealdb_fetchall(batch_size: int | None) -> None:
187204

188205

189206
def test_async_nested_captured_loop_21263() -> None:
190-
# Tests awaiting a future that has "captured" the original event loop from
207+
# tests awaiting a future that has "captured" the original event loop from
191208
# within a `_run_async` context.
192209
async def test_impl() -> None:
193210
loop = asyncio.get_running_loop()

0 commit comments

Comments
 (0)