44import cloudpickle
55import pytest
66from prefect import flow , task
7+ from sqlalchemy import __version__ as SQLALCHEMY_VERSION
78from sqlalchemy .engine import Connection , Engine
89from sqlalchemy .ext .asyncio import AsyncConnection , AsyncEngine
910
@@ -88,6 +89,9 @@ def execute(self, query, params):
8889 )
8990 return cursor_result
9091
92+ def commit (self ):
93+ pass
94+
9195
9296@pytest .fixture ()
9397def sqlalchemy_credentials_async ():
@@ -340,13 +344,21 @@ async def managed_connector_with_data(self, connector_with_data, request):
340344 assert connector_with_data ._engine is None
341345
342346 @pytest .mark .parametrize ("begin" , [True , False ])
343- def test_get_connection (self , begin , managed_connector_with_data ):
347+ async def test_get_connection (self , begin , managed_connector_with_data ):
344348 connection = managed_connector_with_data .get_connection (begin = begin )
345349 if begin :
346350 engine_type = (
347351 AsyncEngine if managed_connector_with_data ._driver_is_async else Engine
348352 )
349- assert isinstance (connection , engine_type ._trans_ctx )
353+
354+ if SQLALCHEMY_VERSION .startswith ("1." ):
355+ assert isinstance (connection , engine_type ._trans_ctx )
356+ elif managed_connector_with_data ._driver_is_async :
357+ async with connection as conn :
358+ assert isinstance (conn , engine_type ._connection_cls )
359+ else :
360+ with connection as conn :
361+ assert isinstance (conn , engine_type ._connection_cls )
350362 else :
351363 engine_type = (
352364 AsyncConnection
@@ -356,15 +368,22 @@ def test_get_connection(self, begin, managed_connector_with_data):
356368 assert isinstance (connection , engine_type )
357369
358370 @pytest .mark .parametrize ("begin" , [True , False ])
359- def test_get_client (self , begin , managed_connector_with_data ):
371+ async def test_get_client (self , begin , managed_connector_with_data ):
360372 connection = managed_connector_with_data .get_client (
361373 client_type = "connection" , begin = begin
362374 )
363375 if begin :
364376 engine_type = (
365377 AsyncEngine if managed_connector_with_data ._driver_is_async else Engine
366378 )
367- assert isinstance (connection , engine_type ._trans_ctx )
379+ if SQLALCHEMY_VERSION .startswith ("1." ):
380+ assert isinstance (connection , engine_type ._trans_ctx )
381+ elif managed_connector_with_data ._driver_is_async :
382+ async with connection as conn :
383+ assert isinstance (conn , engine_type ._connection_cls )
384+ else :
385+ with connection as conn :
386+ assert isinstance (conn , engine_type ._connection_cls )
368387 else :
369388 engine_type = (
370389 AsyncConnection
0 commit comments