88from sqlalchemy .orm import Session , scoped_session
99
1010from sqlalchemy_bind_manager ._bind_manager import SQLAlchemyAsyncBind
11+ from sqlalchemy_bind_manager ._session_handler import AsyncSessionHandler , SessionHandler
1112
1213
1314async def test_session_is_removed_on_cleanup (session_handler_class , sa_bind ):
@@ -25,15 +26,9 @@ async def test_session_is_removed_on_cleanup(session_handler_class, sa_bind):
2526 mocked_remove .assert_called_once ()
2627
2728
28- async def test_session_is_removed_on_cleanup_even_if_loop_is_not_running (
29- session_handler_class , sa_bind
30- ):
31- # This test makes sense only for async implementation
32- if not isinstance (sa_bind , SQLAlchemyAsyncBind ):
33- return
34-
29+ async def test_session_is_removed_on_cleanup_even_if_loop_is_not_running (sa_manager ):
3530 # Running the test without a loop will trigger the loop creation
36- sh = session_handler_class ( sa_bind )
31+ sh = AsyncSessionHandler ( sa_manager . get_bind ( "async" ) )
3732 original_session_remove = sh .scoped_session .remove
3833
3934 with patch .object (
@@ -104,15 +99,9 @@ async def test_rollback_is_called_if_commit_fails(
10499 assert mocked_session .rollback .call_count == int (commit_fails )
105100
106101
107- async def test_session_is_different_on_different_asyncio_tasks (
108- session_handler_class , sa_bind
109- ):
110- # This test makes sense only for async implementation
111- if not isinstance (sa_bind , SQLAlchemyAsyncBind ):
112- return
113-
102+ async def test_session_is_different_on_different_asyncio_tasks (sa_manager ):
114103 # Running the test without a loop will trigger the loop creation
115- sh = session_handler_class ( sa_bind )
104+ sh = AsyncSessionHandler ( sa_manager . get_bind ( "async" ) )
116105
117106 s1 = sh .scoped_session ()
118107 s2 = sh .scoped_session ()
@@ -134,15 +123,9 @@ async def _get_sh_session():
134123 assert s [0 ] is not s [1 ]
135124
136125
137- async def test_session_is_different_on_different_threads (
138- session_handler_class , sa_bind
139- ):
140- # This test makes sense only for sync implementation
141- if isinstance (sa_bind , SQLAlchemyAsyncBind ):
142- return
143-
126+ async def test_session_is_different_on_different_threads (sa_manager ):
144127 # Running the test without a loop will trigger the loop creation
145- sh = session_handler_class ( sa_bind )
128+ sh = SessionHandler ( sa_manager . get_bind ( "sync" ) )
146129
147130 s1 = sh .scoped_session ()
148131 s2 = sh .scoped_session ()
0 commit comments