1+ import asyncio
2+ from multiprocessing .pool import ThreadPool
3+ from time import sleep
14from unittest .mock import AsyncMock , MagicMock , patch
25
36import pytest
4- from sqlalchemy .ext .asyncio import async_scoped_session
5- from sqlalchemy .orm import scoped_session
7+ from sqlalchemy .ext .asyncio import AsyncSession , async_scoped_session
8+ from sqlalchemy .orm import Session , scoped_session
69
710from sqlalchemy_bind_manager ._bind_manager import SQLAlchemyAsyncBind
811
912
1013async def test_session_is_removed_on_cleanup (session_handler_class , sa_bind ):
11- uow = session_handler_class (sa_bind )
12- original_session_remove = uow . _session_class .remove
14+ sh = session_handler_class (sa_bind )
15+ original_session_remove = sh . scoped_session .remove
1316
1417 with patch .object (
15- uow . _session_class ,
18+ sh . scoped_session ,
1619 "remove" ,
1720 wraps = original_session_remove ,
1821 ) as mocked_remove :
1922 # This should trigger the garbage collector and close the session
20- uow = None
23+ sh = None
2124
2225 mocked_remove .assert_called_once ()
2326
2427
25- def test_session_is_removed_on_cleanup_even_if_loop_is_not_running (
28+ async def test_session_is_removed_on_cleanup_even_if_loop_is_not_running (
2629 session_handler_class , sa_bind
2730):
2831 # This test makes sense only for async implementation
2932 if not isinstance (sa_bind , SQLAlchemyAsyncBind ):
3033 return
3134
3235 # Running the test without a loop will trigger the loop creation
33- uow = session_handler_class (sa_bind )
34- original_session_remove = uow . _session_class .remove
36+ sh = session_handler_class (sa_bind )
37+ original_session_remove = sh . scoped_session .remove
3538
3639 with patch .object (
37- uow . _session_class ,
40+ sh . scoped_session ,
3841 "remove" ,
3942 wraps = original_session_remove ,
4043 ) as mocked_close , patch (
4144 "asyncio.get_event_loop" , side_effect = RuntimeError ()
4245 ) as mocked_get_event_loop :
4346 # This should trigger the garbage collector and close the session
44- uow = None
47+ sh = None
4548
4649 mocked_get_event_loop .assert_called_once ()
4750 mocked_close .assert_called_once ()
@@ -55,7 +58,7 @@ async def test_commit_is_called_only_if_not_read_only(
5558 sa_bind ,
5659 sync_async_cm_wrapper ,
5760):
58- uow = session_handler_class (sa_bind )
61+ sh = session_handler_class (sa_bind )
5962
6063 # Populate a database entry to be used for tests
6164 model1 = model_class (
@@ -64,13 +67,13 @@ async def test_commit_is_called_only_if_not_read_only(
6467
6568 with patch .object (
6669 session_handler_class , "commit" , return_value = None
67- ) as mocked_uow_commit :
70+ ) as mocked_sh_commit :
6871 async with sync_async_cm_wrapper (
69- uow .get_session (read_only = read_only_flag )
72+ sh .get_session (read_only = read_only_flag )
7073 ) as _session :
7174 _session .add (model1 )
7275
73- assert mocked_uow_commit .call_count == int (not read_only_flag )
76+ assert mocked_sh_commit .call_count == int (not read_only_flag )
7477
7578
7679@pytest .mark .parametrize ("commit_fails" , [True , False ])
@@ -80,23 +83,88 @@ async def test_rollback_is_called_if_commit_fails(
8083 sa_bind ,
8184 sync_async_wrapper ,
8285):
83- uow = session_handler_class (sa_bind )
86+ sh = session_handler_class (sa_bind )
8487
8588 failure_exception = Exception ("Some Error" )
8689 mocked_session = (
8790 AsyncMock (spec = async_scoped_session )
8891 if isinstance (sa_bind , SQLAlchemyAsyncBind )
8992 else MagicMock (spec = scoped_session )
9093 )
91- uow .session = mocked_session
9294 if commit_fails :
9395 mocked_session .commit .side_effect = failure_exception
9496
9597 try :
96- await sync_async_wrapper (uow .commit ())
98+ await sync_async_wrapper (sh .commit (mocked_session ))
9799 except Exception as e :
98100 assert commit_fails is True
99101 assert e == failure_exception
100102
101103 assert mocked_session .commit .call_count == 1
102104 assert mocked_session .rollback .call_count == int (commit_fails )
105+
106+
107+ async def test_session_is_different_on_different_asyncio_tasks (
108+ session_handler_class , sa_bind
109+ ):
110+ # This test makes sense only for async implementation
111+ if not isinstance (sa_bind , SQLAlchemyAsyncBind ):
112+ return
113+
114+ # Running the test without a loop will trigger the loop creation
115+ sh = session_handler_class (sa_bind )
116+
117+ s1 = sh .scoped_session ()
118+ s2 = sh .scoped_session ()
119+
120+ assert isinstance (s1 , AsyncSession )
121+ assert isinstance (s2 , AsyncSession )
122+ assert s1 is s2
123+
124+ async def _get_sh_session ():
125+ return sh .scoped_session ()
126+
127+ s = await asyncio .gather (
128+ _get_sh_session (),
129+ _get_sh_session (),
130+ )
131+
132+ assert isinstance (s [0 ], AsyncSession )
133+ assert isinstance (s [1 ], AsyncSession )
134+ assert s [0 ] is not s [1 ]
135+
136+
137+ async def test_session_is_different_on_different_threads (
138+ session_handler_class , sa_bind
139+ ):
140+ # This test makes sense only for sync implementation
141+ if isinstance (sa_bind , SQLAlchemyAsyncBind ):
142+ return
143+
144+ # Running the test without a loop will trigger the loop creation
145+ sh = session_handler_class (sa_bind )
146+
147+ s1 = sh .scoped_session ()
148+ s2 = sh .scoped_session ()
149+
150+ assert isinstance (s1 , Session )
151+ assert isinstance (s2 , Session )
152+ assert s1 is s2
153+
154+ def _get_session ():
155+ # This sleep is to make sure the task doesn't
156+ # resolve immediately and multiple instances
157+ # end up in different threads
158+ sleep (1 )
159+ return sh .scoped_session ()
160+
161+ with ThreadPool () as pool :
162+ s3_task = pool .apply_async (_get_session )
163+ s4_task = pool .apply_async (_get_session )
164+
165+ s3 = s3_task .get ()
166+ s4 = s4_task .get ()
167+
168+ assert isinstance (s3 , Session )
169+ assert isinstance (s4 , Session )
170+ assert s3 is not s4
0 commit comments