Skip to content
This repository was archived by the owner on Mar 19, 2026. It is now read-only.

Commit 5ce03d1

Browse files
authored
Merge pull request #70 from obendidi/sqlalchemy_v2
sqlalchemy v2 (fix tests)
2 parents f47bf72 + 84bcb06 commit 5ce03d1

File tree

3 files changed

+29
-5
lines changed

3 files changed

+29
-5
lines changed

prefect_sqlalchemy/database.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
else:
1717
from pydantic import AnyUrl, Field, SecretStr
1818

19+
from sqlalchemy import __version__ as SQLALCHEMY_VERSION
1920
from sqlalchemy.engine import Connection, Engine, create_engine
2021
from sqlalchemy.engine.cursor import CursorResult
2122
from sqlalchemy.engine.url import URL, make_url
@@ -67,6 +68,8 @@ async def _execute(
6768
if async_supported:
6869
result = await result
6970
await connection.commit()
71+
elif SQLALCHEMY_VERSION.startswith("2."):
72+
connection.commit()
7073
return result
7174

7275

@@ -497,6 +500,8 @@ async def _async_sync_execute(
497500
if self._driver_is_async:
498501
result_set = await result_set
499502
await connection.commit() # very important
503+
elif SQLALCHEMY_VERSION.startswith("2."):
504+
connection.commit()
500505
return result_set
501506

502507
@asynccontextmanager

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
prefect>=2.13.5
2-
sqlalchemy>=1.4.31,<2
2+
sqlalchemy>=1.4.31,<3

tests/test_database.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import cloudpickle
55
import pytest
66
from prefect import flow, task
7+
from sqlalchemy import __version__ as SQLALCHEMY_VERSION
78
from sqlalchemy.engine import Connection, Engine
89
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine
910

@@ -88,6 +89,9 @@ def execute(self, query, params):
8889
)
8990
return cursor_result
9091

92+
def commit(self):
93+
pass
94+
9195

9296
@pytest.fixture()
9397
def sqlalchemy_credentials_async():
@@ -340,13 +344,21 @@ async def managed_connector_with_data(self, connector_with_data, request):
340344
assert connector_with_data._engine is None
341345

342346
@pytest.mark.parametrize("begin", [True, False])
343-
def test_get_connection(self, begin, managed_connector_with_data):
347+
async def test_get_connection(self, begin, managed_connector_with_data):
344348
connection = managed_connector_with_data.get_connection(begin=begin)
345349
if begin:
346350
engine_type = (
347351
AsyncEngine if managed_connector_with_data._driver_is_async else Engine
348352
)
349-
assert isinstance(connection, engine_type._trans_ctx)
353+
354+
if SQLALCHEMY_VERSION.startswith("1."):
355+
assert isinstance(connection, engine_type._trans_ctx)
356+
elif managed_connector_with_data._driver_is_async:
357+
async with connection as conn:
358+
assert isinstance(conn, engine_type._connection_cls)
359+
else:
360+
with connection as conn:
361+
assert isinstance(conn, engine_type._connection_cls)
350362
else:
351363
engine_type = (
352364
AsyncConnection
@@ -356,15 +368,22 @@ def test_get_connection(self, begin, managed_connector_with_data):
356368
assert isinstance(connection, engine_type)
357369

358370
@pytest.mark.parametrize("begin", [True, False])
359-
def test_get_client(self, begin, managed_connector_with_data):
371+
async def test_get_client(self, begin, managed_connector_with_data):
360372
connection = managed_connector_with_data.get_client(
361373
client_type="connection", begin=begin
362374
)
363375
if begin:
364376
engine_type = (
365377
AsyncEngine if managed_connector_with_data._driver_is_async else Engine
366378
)
367-
assert isinstance(connection, engine_type._trans_ctx)
379+
if SQLALCHEMY_VERSION.startswith("1."):
380+
assert isinstance(connection, engine_type._trans_ctx)
381+
elif managed_connector_with_data._driver_is_async:
382+
async with connection as conn:
383+
assert isinstance(conn, engine_type._connection_cls)
384+
else:
385+
with connection as conn:
386+
assert isinstance(conn, engine_type._connection_cls)
368387
else:
369388
engine_type = (
370389
AsyncConnection

0 commit comments

Comments
 (0)