|
3 | 3 |
|
4 | 4 | import cloudpickle |
5 | 5 | import pytest |
6 | | -from prefect import flow |
| 6 | +from prefect import flow, task |
7 | 7 | from sqlalchemy.engine import Connection, Engine |
8 | 8 | from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine |
9 | 9 |
|
@@ -261,22 +261,30 @@ async def test_connector_init(self): |
261 | 261 | credentials_url = SqlAlchemyConnector(connection_info=connection_url) |
262 | 262 | assert credentials_components._rendered_url == credentials_url._rendered_url |
263 | 263 |
|
264 | | - def test_delay_start(self, caplog): |
| 264 | + @pytest.mark.parametrize("method", ["fetch_all", "execute"]) |
| 265 | + def test_delay_start(self, caplog, method): |
265 | 266 | with SqlAlchemyConnector( |
266 | 267 | connection_info=ConnectionComponents( |
267 | 268 | driver=SyncDriver.SQLITE_PYSQLITE, |
268 | 269 | database=":memory:", |
269 | 270 | ), |
270 | 271 | ) as connector: |
| 272 | + assert connector._unique_results == {} |
| 273 | + assert isinstance(connector._exit_stack, ExitStack) |
271 | 274 | connector.reset_connections() |
272 | | - assert caplog.records[0].msg == "There were no connections to reset." |
| 275 | + assert ( |
| 276 | + caplog.records[0].msg == "Reset opened connections and their results." |
| 277 | + ) |
273 | 278 | assert connector._engine is None |
274 | | - assert connector._unique_results is None |
275 | | - assert connector._exit_stack is None |
276 | | - connector.execute("SELECT 1") |
277 | | - assert isinstance(connector._engine, Engine) |
278 | 279 | assert connector._unique_results == {} |
279 | 280 | assert isinstance(connector._exit_stack, ExitStack) |
| 281 | + getattr(connector, method)("SELECT 1") |
| 282 | + assert isinstance(connector._engine, Engine) |
| 283 | + if method == "execute": |
| 284 | + assert connector._unique_results == {} |
| 285 | + else: |
| 286 | + assert len(connector._unique_results) == 1 |
| 287 | + assert isinstance(connector._exit_stack, ExitStack) |
280 | 288 |
|
281 | 289 | @pytest.fixture(params=[SyncDriver.SQLITE_PYSQLITE, AsyncDriver.SQLITE_AIOSQLITE]) |
282 | 290 | async def connector_with_data(self, tmp_path, request): |
@@ -547,3 +555,52 @@ def test_sync_compatible_reset_connections(self, tmp_path): |
547 | 555 | assert len(conn._unique_results) == 1 |
548 | 556 | conn.reset_connections() |
549 | 557 | assert len(conn._unique_results) == 0 |
| 558 | + |
| 559 | + def test_flow_without_initialized_engine(self, tmp_path): |
| 560 | + @task |
| 561 | + def setup_table(block_name: str) -> None: |
| 562 | + with SqlAlchemyConnector.load(block_name) as connector: |
| 563 | + connector.execute( |
| 564 | + "CREATE TABLE IF NOT EXISTS customers (name varchar, address varchar);" # noqa |
| 565 | + ) |
| 566 | + connector.execute( |
| 567 | + "INSERT INTO customers (name, address) VALUES (:name, :address);", |
| 568 | + parameters={"name": "Marvin", "address": "Highway 42"}, |
| 569 | + ) |
| 570 | + connector.execute_many( |
| 571 | + "INSERT INTO customers (name, address) VALUES (:name, :address);", |
| 572 | + seq_of_parameters=[ |
| 573 | + {"name": "Ford", "address": "Highway 42"}, |
| 574 | + {"name": "Unknown", "address": "Highway 42"}, |
| 575 | + ], |
| 576 | + ) |
| 577 | + |
| 578 | + @task |
| 579 | + def fetch_data(block_name: str) -> list: |
| 580 | + all_rows = [] |
| 581 | + with SqlAlchemyConnector.load(block_name) as connector: |
| 582 | + while True: |
| 583 | + # Repeated fetch* calls using the same operation will |
| 584 | + # skip re-executing and instead return the next set of results |
| 585 | + new_rows = connector.fetch_many("SELECT * FROM customers", size=2) |
| 586 | + if len(new_rows) == 0: |
| 587 | + break |
| 588 | + all_rows.append(new_rows) |
| 589 | + return all_rows |
| 590 | + |
| 591 | + @flow |
| 592 | + def sqlalchemy_flow(block_name: str) -> list: |
| 593 | + SqlAlchemyConnector( |
| 594 | + connection_info=ConnectionComponents( |
| 595 | + driver=SyncDriver.SQLITE_PYSQLITE, |
| 596 | + database=str(tmp_path / "test.db"), |
| 597 | + ) |
| 598 | + ).save(block_name) |
| 599 | + setup_table(block_name) |
| 600 | + all_rows = fetch_data(block_name) |
| 601 | + return all_rows |
| 602 | + |
| 603 | + assert sqlalchemy_flow("connector") == [ |
| 604 | + [("Marvin", "Highway 42"), ("Ford", "Highway 42")], |
| 605 | + [("Unknown", "Highway 42")], |
| 606 | + ] |
0 commit comments