Skip to content

Commit f5b3878

Browse files
committed
Make session scoped to threads and asyncio tasks
1 parent 79f0f03 commit f5b3878

File tree

4 files changed

+127
-62
lines changed

4 files changed

+127
-62
lines changed
Lines changed: 30 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import asyncio
22
from contextlib import asynccontextmanager, contextmanager
33
from typing import AsyncIterator, Iterator
4-
from uuid import uuid4
54

65
from sqlalchemy.ext.asyncio import (
76
AsyncSession,
@@ -17,91 +16,87 @@
1716

1817

1918
class SessionHandler:
20-
_session_class: scoped_session
21-
session: Session
19+
scoped_session: scoped_session
2220

2321
def __init__(self, bind: SQLAlchemyBind):
2422
if not isinstance(bind, SQLAlchemyBind):
2523
raise UnsupportedBind("Bind is not an instance of SQLAlchemyBind")
2624
else:
27-
u = uuid4()
28-
self._session_class = scoped_session(
29-
bind.session_class, scopefunc=lambda: str(u)
30-
)
31-
self.session = self._session_class()
25+
self.scoped_session = scoped_session(bind.session_class)
3226

3327
def __del__(self):
34-
if getattr(self, "_session_class", None):
35-
self._session_class.remove()
28+
if getattr(self, "scoped_session", None):
29+
self.scoped_session.remove()
3630

3731
@contextmanager
3832
def get_session(self, read_only: bool = False) -> Iterator[Session]:
33+
session = self.scoped_session()
3934
try:
40-
self.session.begin()
41-
yield self.session
35+
session.begin()
36+
yield session
4237
if not read_only:
43-
self.commit()
38+
self.commit(session)
4439
finally:
45-
self.session.close()
40+
session.close()
4641

47-
def commit(self) -> None:
42+
def commit(self, session: Session) -> None:
4843
"""Commits the session and handles rollback on errors.
4944
45+
:param session: The session object.
46+
:type session: Session
5047
:raises Exception: Any error is re-raised after the rollback.
5148
"""
5249
try:
53-
self.session.commit()
50+
session.commit()
5451
except:
55-
self.session.rollback()
52+
session.rollback()
5653
raise
5754

5855

5956
class AsyncSessionHandler:
60-
_session_class: async_scoped_session
61-
session: AsyncSession
57+
scoped_session: async_scoped_session
6258

6359
def __init__(self, bind: SQLAlchemyAsyncBind):
6460
if not isinstance(bind, SQLAlchemyAsyncBind):
6561
raise UnsupportedBind("Bind is not an instance of SQLAlchemyAsyncBind")
6662
else:
67-
u = uuid4()
68-
self._session_class = async_scoped_session(
69-
bind.session_class, scopefunc=lambda: str(u)
63+
self.scoped_session = async_scoped_session(
64+
bind.session_class, asyncio.current_task
7065
)
71-
self.session = self._session_class()
7266

7367
def __del__(self):
74-
if not getattr(self, "_session_class", None):
68+
if not getattr(self, "scoped_session", None):
7569
return
7670

7771
try:
7872
loop = asyncio.get_event_loop()
7973
if loop.is_running():
80-
loop.create_task(self._session_class.remove())
74+
loop.create_task(self.scoped_session.remove())
8175
else:
82-
loop.run_until_complete(self._session_class.remove())
76+
loop.run_until_complete(self.scoped_session.remove())
8377
except RuntimeError:
84-
asyncio.run(self._session_class.remove())
78+
asyncio.run(self.scoped_session.remove())
8579

8680
@asynccontextmanager
8781
async def get_session(self, read_only: bool = False) -> AsyncIterator[AsyncSession]:
82+
session = self.scoped_session()
8883
try:
89-
await self.session.begin()
90-
yield self.session
84+
await session.begin()
85+
yield session
9186
if not read_only:
92-
await self.commit()
87+
await self.commit(session)
9388
finally:
94-
await self.session.close()
89+
await session.close()
9590

96-
async def commit(self) -> None:
91+
async def commit(self, session: AsyncSession) -> None:
9792
"""Commits the session and handles rollback on errors.
9893
9994
:param session: The session object.
100-
:type session: Session
95+
:type session: AsyncSession
10196
:raises Exception: Any error is re-raised after the rollback.
10297
"""
10398
try:
104-
await self.session.commit()
99+
await session.commit()
105100
except:
106-
await self.session.rollback()
101+
await session.rollback()
107102
raise

sqlalchemy_bind_manager/_unit_of_work/__init__.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,36 +16,36 @@
1616

1717

1818
class UnitOfWork:
19-
_transaction_handler: SessionHandler
19+
_session_handler: SessionHandler
2020

2121
def __init__(
2222
self, bind: SQLAlchemyBind, repositories: Iterable[Type[SQLAlchemyRepository]]
2323
) -> None:
2424
super().__init__()
25-
self._transaction_handler = SessionHandler(bind)
25+
self._session_handler = SessionHandler(bind)
2626
for r in repositories:
27-
setattr(self, r.__name__, r(session=self._transaction_handler.session))
27+
setattr(self, r.__name__, r(session=self._session_handler.scoped_session()))
2828

2929
@contextmanager
3030
def transaction(self, read_only: bool = False) -> Iterator[Session]:
31-
with self._transaction_handler.get_session(read_only=read_only) as _s:
31+
with self._session_handler.get_session(read_only=read_only) as _s:
3232
yield _s
3333

3434

3535
class AsyncUnitOfWork:
36-
_transaction_handler: AsyncSessionHandler
36+
_session_handler: AsyncSessionHandler
3737

3838
def __init__(
3939
self,
4040
bind: SQLAlchemyAsyncBind,
4141
repositories: Iterable[Type[SQLAlchemyAsyncRepository]],
4242
) -> None:
4343
super().__init__()
44-
self._transaction_handler = AsyncSessionHandler(bind)
44+
self._session_handler = AsyncSessionHandler(bind)
4545
for r in repositories:
46-
setattr(self, r.__name__, r(session=self._transaction_handler.session))
46+
setattr(self, r.__name__, r(session=self._session_handler.scoped_session()))
4747

4848
@asynccontextmanager
4949
async def transaction(self, read_only: bool = False) -> AsyncIterator[AsyncSession]:
50-
async with self._transaction_handler.get_session(read_only=read_only) as _s:
50+
async with self._session_handler.get_session(read_only=read_only) as _s:
5151
yield _s
Lines changed: 86 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,50 @@
1+
import asyncio
2+
from multiprocessing.pool import ThreadPool
3+
from time import sleep
14
from unittest.mock import AsyncMock, MagicMock, patch
25

36
import pytest
4-
from sqlalchemy.ext.asyncio import async_scoped_session
5-
from sqlalchemy.orm import scoped_session
7+
from sqlalchemy.ext.asyncio import AsyncSession, async_scoped_session
8+
from sqlalchemy.orm import Session, scoped_session
69

710
from sqlalchemy_bind_manager._bind_manager import SQLAlchemyAsyncBind
811

912

1013
async def test_session_is_removed_on_cleanup(session_handler_class, sa_bind):
11-
uow = session_handler_class(sa_bind)
12-
original_session_remove = uow._session_class.remove
14+
sh = session_handler_class(sa_bind)
15+
original_session_remove = sh.scoped_session.remove
1316

1417
with patch.object(
15-
uow._session_class,
18+
sh.scoped_session,
1619
"remove",
1720
wraps=original_session_remove,
1821
) as mocked_remove:
1922
# This should trigger the garbage collector and close the session
20-
uow = None
23+
sh = None
2124

2225
mocked_remove.assert_called_once()
2326

2427

25-
def test_session_is_removed_on_cleanup_even_if_loop_is_not_running(
28+
async def test_session_is_removed_on_cleanup_even_if_loop_is_not_running(
2629
session_handler_class, sa_bind
2730
):
2831
# This test makes sense only for async implementation
2932
if not isinstance(sa_bind, SQLAlchemyAsyncBind):
3033
return
3134

3235
# Running the test without a loop will trigger the loop creation
33-
uow = session_handler_class(sa_bind)
34-
original_session_remove = uow._session_class.remove
36+
sh = session_handler_class(sa_bind)
37+
original_session_remove = sh.scoped_session.remove
3538

3639
with patch.object(
37-
uow._session_class,
40+
sh.scoped_session,
3841
"remove",
3942
wraps=original_session_remove,
4043
) as mocked_close, patch(
4144
"asyncio.get_event_loop", side_effect=RuntimeError()
4245
) as mocked_get_event_loop:
4346
# This should trigger the garbage collector and close the session
44-
uow = None
47+
sh = None
4548

4649
mocked_get_event_loop.assert_called_once()
4750
mocked_close.assert_called_once()
@@ -55,7 +58,7 @@ async def test_commit_is_called_only_if_not_read_only(
5558
sa_bind,
5659
sync_async_cm_wrapper,
5760
):
58-
uow = session_handler_class(sa_bind)
61+
sh = session_handler_class(sa_bind)
5962

6063
# Populate a database entry to be used for tests
6164
model1 = model_class(
@@ -64,13 +67,13 @@ async def test_commit_is_called_only_if_not_read_only(
6467

6568
with patch.object(
6669
session_handler_class, "commit", return_value=None
67-
) as mocked_uow_commit:
70+
) as mocked_sh_commit:
6871
async with sync_async_cm_wrapper(
69-
uow.get_session(read_only=read_only_flag)
72+
sh.get_session(read_only=read_only_flag)
7073
) as _session:
7174
_session.add(model1)
7275

73-
assert mocked_uow_commit.call_count == int(not read_only_flag)
76+
assert mocked_sh_commit.call_count == int(not read_only_flag)
7477

7578

7679
@pytest.mark.parametrize("commit_fails", [True, False])
@@ -80,23 +83,88 @@ async def test_rollback_is_called_if_commit_fails(
8083
sa_bind,
8184
sync_async_wrapper,
8285
):
83-
uow = session_handler_class(sa_bind)
86+
sh = session_handler_class(sa_bind)
8487

8588
failure_exception = Exception("Some Error")
8689
mocked_session = (
8790
AsyncMock(spec=async_scoped_session)
8891
if isinstance(sa_bind, SQLAlchemyAsyncBind)
8992
else MagicMock(spec=scoped_session)
9093
)
91-
uow.session = mocked_session
9294
if commit_fails:
9395
mocked_session.commit.side_effect = failure_exception
9496

9597
try:
96-
await sync_async_wrapper(uow.commit())
98+
await sync_async_wrapper(sh.commit(mocked_session))
9799
except Exception as e:
98100
assert commit_fails is True
99101
assert e == failure_exception
100102

101103
assert mocked_session.commit.call_count == 1
102104
assert mocked_session.rollback.call_count == int(commit_fails)
105+
106+
107+
async def test_session_is_different_on_different_asyncio_tasks(
108+
session_handler_class, sa_bind
109+
):
110+
# This test makes sense only for async implementation
111+
if not isinstance(sa_bind, SQLAlchemyAsyncBind):
112+
return
113+
114+
# Running the test without a loop will trigger the loop creation
115+
sh = session_handler_class(sa_bind)
116+
117+
s1 = sh.scoped_session()
118+
s2 = sh.scoped_session()
119+
120+
assert isinstance(s1, AsyncSession)
121+
assert isinstance(s2, AsyncSession)
122+
assert s1 is s2
123+
124+
async def _get_sh_session():
125+
return sh.scoped_session()
126+
127+
s = await asyncio.gather(
128+
_get_sh_session(),
129+
_get_sh_session(),
130+
)
131+
132+
assert isinstance(s[0], AsyncSession)
133+
assert isinstance(s[1], AsyncSession)
134+
assert s[0] is not s[1]
135+
136+
137+
async def test_session_is_different_on_different_threads(
138+
session_handler_class, sa_bind
139+
):
140+
# This test makes sense only for sync implementation
141+
if isinstance(sa_bind, SQLAlchemyAsyncBind):
142+
return
143+
144+
# Running the test without a loop will trigger the loop creation
145+
sh = session_handler_class(sa_bind)
146+
147+
s1 = sh.scoped_session()
148+
s2 = sh.scoped_session()
149+
150+
assert isinstance(s1, Session)
151+
assert isinstance(s2, Session)
152+
assert s1 is s2
153+
154+
def _get_session():
155+
# This sleep is to make sure the task doesn't
156+
# resolve immediately and multiple instances
157+
# end up in different threads
158+
sleep(1)
159+
return sh.scoped_session()
160+
161+
with ThreadPool() as pool:
162+
s3_task = pool.apply_async(_get_session)
163+
s4_task = pool.apply_async(_get_session)
164+
165+
s3 = s3_task.get()
166+
s4 = s4_task.get()
167+
168+
assert isinstance(s3, Session)
169+
assert isinstance(s4, Session)
170+
assert s3 is not s4

tests/unit_of_work/test_lifecycle.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,6 @@ class ChildRepoClass(repository_class):
1717
repo = getattr(uow, repo_class.__name__)
1818
assert not hasattr(repo, "_session_handler")
1919
assert hasattr(repo, "_external_session")
20-
assert getattr(repo, "_external_session") is uow._transaction_handler.session
20+
assert (
21+
getattr(repo, "_external_session") is uow._session_handler.scoped_session()
22+
)

0 commit comments

Comments
 (0)