|
19 | 19 | from collections.abc import Iterable |
20 | 20 | from pathlib import Path |
21 | 21 |
|
| 22 | + |
22 | 23 | SURREAL_MOCK_DATA: list[dict[str, Any]] = [ |
23 | 24 | { |
24 | 25 | "id": "item:8xj31jfpdkf9gvmxdxpi", |
@@ -84,58 +85,74 @@ class MockedSurrealModule(ModuleType): |
84 | 85 | ) |
85 | 86 | def test_read_async(tmp_sqlite_db: Path) -> None: |
86 | 87 | # confirm that we can load frame data from the core sqlalchemy async |
87 | | - # primitives: AsyncConnection, AsyncEngine, and async_sessionmaker |
| 88 | + # primitives: AsyncEngine, AsyncConnection, async_sessionmaker, and AsyncSession |
88 | 89 | from sqlalchemy.ext.asyncio import async_sessionmaker |
89 | 90 |
|
90 | | - async_engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_sqlite_db}") |
91 | | - async_connection = async_engine.connect() |
92 | | - async_session = async_sessionmaker(async_engine) |
93 | | - async_session_inst = async_session() |
94 | | - |
95 | | - expected_frame = pl.DataFrame( |
96 | | - {"id": [2, 1], "name": ["other", "misc"], "value": [-99.5, 100.0]} |
97 | | - ) |
98 | | - async_conn: Any |
99 | | - for async_conn in ( |
100 | | - async_engine, |
101 | | - async_connection, |
102 | | - async_session, |
103 | | - async_session_inst, |
104 | | - ): |
105 | | - if async_conn in (async_session, async_session_inst): |
106 | | - constraint, execute_opts = "", {} |
107 | | - else: |
108 | | - constraint = "WHERE value > :n" |
109 | | - execute_opts = {"parameters": {"n": -1000}} |
110 | | - |
111 | | - df = pl.read_database( |
112 | | - query=f""" |
113 | | - SELECT id, name, value |
114 | | - FROM test_data {constraint} |
115 | | - ORDER BY id DESC |
116 | | - """, |
117 | | - connection=async_conn, |
118 | | - execute_options=execute_opts, |
119 | | - ) |
120 | | - assert_frame_equal(expected_frame, df) |
| 91 | + async def _test_impl() -> None: |
| 92 | + async_engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_sqlite_db}") |
| 93 | + async_connection = await async_engine.connect() |
| 94 | + try: |
| 95 | + async_session = async_sessionmaker(async_engine) |
| 96 | + async_session_inst = async_session() |
121 | 97 |
|
| 98 | + expected_frame = pl.DataFrame( |
| 99 | + {"id": [2, 1], "name": ["other", "misc"], "value": [-99.5, 100.0]} |
| 100 | + ) |
| 101 | + async_conn: Any |
| 102 | + for async_conn in ( |
| 103 | + async_engine, |
| 104 | + async_connection, |
| 105 | + async_session, |
| 106 | + async_session_inst, |
| 107 | + ): |
| 108 | + if async_conn in (async_session, async_session_inst): |
| 109 | + constraint, execute_opts = "", {} |
| 110 | + else: |
| 111 | + constraint = "WHERE value > :n" |
| 112 | + execute_opts = {"parameters": {"n": -1000}} |
| 113 | + |
| 114 | + df = pl.read_database( |
| 115 | + query=f""" |
| 116 | + SELECT id, name, value |
| 117 | + FROM test_data {constraint} |
| 118 | + ORDER BY id DESC |
| 119 | + """, |
| 120 | + connection=async_conn, |
| 121 | + execute_options=execute_opts, |
| 122 | + ) |
| 123 | + assert_frame_equal(expected_frame, df) |
| 124 | + finally: |
| 125 | + await async_session_inst.close() |
| 126 | + await async_connection.close() |
| 127 | + await async_engine.dispose() |
122 | 128 |
|
123 | | -async def _nested_async_test(tmp_sqlite_db: Path) -> pl.DataFrame: |
124 | | - async_engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_sqlite_db}") |
125 | | - return pl.read_database( |
126 | | - query="SELECT id, name FROM test_data ORDER BY id", |
127 | | - connection=async_engine.connect(), |
128 | | - ) |
| 129 | + asyncio.run(_test_impl()) |
129 | 130 |
|
130 | 131 |
|
131 | 132 | @pytest.mark.skipif( |
132 | 133 | parse_version(sqlalchemy.__version__) < (2, 0), |
133 | 134 | reason="SQLAlchemy 2.0+ required for async tests", |
134 | 135 | ) |
135 | | -def test_read_async_nested(tmp_sqlite_db: Path) -> None: |
136 | | - # This tests validates that we can handle nested async calls |
| 136 | +@pytest.mark.parametrize("started", [True, False]) |
| 137 | +def test_read_async_nested(tmp_sqlite_db: Path, started: bool) -> None: |
| 138 | + # validate that we can handle nested async calls; check |
| 139 | + # this works with connections that are started/unstarted |
| 140 | + async def _test_impl() -> pl.DataFrame: |
| 141 | + async_engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_sqlite_db}") |
| 142 | + async_connection = async_engine.connect() |
| 143 | + if started: |
| 144 | + async_connection = await async_connection |
| 145 | + try: |
| 146 | + return pl.read_database( |
| 147 | + query="SELECT id, name FROM test_data ORDER BY id", |
| 148 | + connection=async_connection, |
| 149 | + ) |
| 150 | + finally: |
| 151 | + await async_connection.close() |
| 152 | + await async_engine.dispose() |
| 153 | + |
137 | 154 | expected_frame = pl.DataFrame({"id": [1, 2], "name": ["misc", "other"]}) |
138 | | - df = asyncio.run(_nested_async_test(tmp_sqlite_db)) |
| 155 | + df = asyncio.run(_test_impl()) |
139 | 156 | assert_frame_equal(expected_frame, df) |
140 | 157 |
|
141 | 158 |
|
@@ -187,7 +204,7 @@ def test_surrealdb_fetchall(batch_size: int | None) -> None: |
187 | 204 |
|
188 | 205 |
|
189 | 206 | def test_async_nested_captured_loop_21263() -> None: |
190 | | - # Tests awaiting a future that has "captured" the original event loop from |
| 207 | + # tests awaiting a future that has "captured" the original event loop from |
191 | 208 | # within a `_run_async` context. |
192 | 209 | async def test_impl() -> None: |
193 | 210 | loop = asyncio.get_running_loop() |
|
0 commit comments