Skip to content

Commit c94e2fd

Browse files
committed
drop async session scope. Allow user to manage session connection scope and also manage closure when app has a response
1 parent 140290e commit c94e2fd

File tree

6 files changed

+74
-77
lines changed

6 files changed

+74
-77
lines changed

ellar_sql/module.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,33 +27,32 @@ def _raise_exception():
2727
return _raise_exception
2828

2929

30-
async def _session_cleanup(
31-
db_service: EllarSQLService, session: t.Union[Session, AsyncSession]
32-
) -> None:
33-
res = session.close()
34-
if isinstance(res, t.Coroutine):
35-
await res
36-
37-
res = db_service.session_factory.remove()
38-
if isinstance(res, t.Coroutine):
39-
await res
40-
41-
4230
@as_middleware
4331
async def session_middleware(
4432
context: IHostContext, call_next: t.Callable[..., t.Coroutine]
4533
):
4634
connection = context.switch_to_http_connection().get_client()
47-
4835
db_service = context.get_service_provider().get(EllarSQLService)
49-
session = db_service.session_factory()
5036

37+
# Create a NEW session for this request
38+
session = db_service.session_factory()
5139
connection.state.session = session
5240

5341
try:
5442
await call_next()
43+
except Exception as ex:
44+
# Only rollback if session is still active
45+
if session.is_active and session.in_transaction():
46+
res = session.rollback()
47+
if isinstance(res, t.Coroutine):
48+
await res
49+
raise ex
5550
finally:
56-
await _session_cleanup(db_service, session)
51+
# Always clean up
52+
if session.is_active:
53+
res = session.close()
54+
if isinstance(res, t.Coroutine):
55+
await res
5756

5857

5958
@Module(

ellar_sql/pagination/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ async def _close_session(self) -> None:
286286
def _get_session(self) -> t.Union[sa_orm.Session, AsyncSession, t.Any]:
287287
self._created_session = True
288288
service = current_injector.get(EllarSQLService)
289-
return service.get_scoped_session()()
289+
return service.session_factory_maker()()
290290

291291
def _query_items(self) -> t.List[t.Any]:
292292
if self._is_async:

ellar_sql/query/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ async def get_or_404(
1919
) -> _O:
2020
""" """
2121
db_service = current_injector.get(EllarSQLService)
22-
session = db_service.get_scoped_session()()
22+
session = db_service.session_factory_maker()()
2323

2424
value = session.get(entity, ident, **kwargs)
2525

@@ -39,7 +39,7 @@ async def get_or_none(
3939
) -> t.Optional[_O]:
4040
""" """
4141
db_service = current_injector.get(EllarSQLService)
42-
session = db_service.get_scoped_session()()
42+
session = db_service.session_factory_maker()()
4343

4444
value = session.get(entity, ident, **kwargs)
4545

ellar_sql/services/base.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import os
22
import typing as t
3-
from threading import get_ident
43
from weakref import WeakKeyDictionary
54

65
import sqlalchemy as sa
@@ -14,7 +13,6 @@
1413
)
1514
from sqlalchemy.ext.asyncio import (
1615
AsyncSession,
17-
async_scoped_session,
1816
async_sessionmaker,
1917
)
2018

@@ -32,6 +30,10 @@
3230

3331

3432
class EllarSQLService:
33+
session_factory: t.Union[
34+
sa_orm.sessionmaker[sa_orm.Session], async_sessionmaker[AsyncSession]
35+
]
36+
3537
def __init__(
3638
self,
3739
databases: t.Union[str, t.Dict[str, t.Any]],
@@ -61,7 +63,7 @@ def __init__(
6163
self._has_async_engine_driver: bool = False
6264

6365
self._setup(databases, models=models, echo=echo)
64-
self.session_factory = self.get_scoped_session()
66+
self.session_factory = self.session_factory_maker()
6567

6668
@property
6769
def has_async_engine_driver(self) -> bool:
@@ -177,24 +179,16 @@ def reflect(self, *databases: str) -> None:
177179
continue
178180
metadata_engine.reflect()
179181

180-
def get_scoped_session(
182+
def session_factory_maker(
181183
self,
182184
**extra_options: t.Any,
183-
) -> t.Union[
184-
sa_orm.scoped_session[sa_orm.Session],
185-
async_scoped_session[t.Union[AsyncSession, t.Any]],
186-
]:
185+
) -> t.Union[sa_orm.sessionmaker[sa_orm.Session], async_sessionmaker[AsyncSession]]:
187186
options = self._session_options.copy()
188187
options.update(extra_options)
189188

190-
scope = options.pop("scopefunc", get_ident)
191-
192-
factory = self._make_session_factory(options)
193-
194-
if self.has_async_engine_driver:
195-
return async_scoped_session(factory, scope) # type:ignore[arg-type]
189+
scope = options.pop("scopefunc", None) # noqa: F841
196190

197-
return sa_orm.scoped_session(factory, scope) # type:ignore[arg-type]
191+
return self._make_session_factory(options)
198192

199193
def _make_session_factory(
200194
self, options: t.Dict[str, t.Any]

tests/test_model_export.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def test_model_export_without_filter(self, db_service, ignore_base):
4343
"id": 1,
4444
"name": "Ellar",
4545
}
46-
db_service.session_factory.close()
46+
# db_service.session_factory.close()
4747

4848
def test_model_exclude_none(self, db_service, ignore_base):
4949
user_factory = get_model_factory(db_service)
@@ -59,7 +59,7 @@ def test_model_exclude_none(self, db_service, ignore_base):
5959
"id": 1,
6060
"name": "Ellar",
6161
}
62-
db_service.session_factory.close()
62+
# db_service.session_factory.close()
6363

6464
def test_model_export_include(self, db_service, ignore_base):
6565
user_factory = get_model_factory(db_service)
@@ -73,7 +73,7 @@ def test_model_export_include(self, db_service, ignore_base):
7373
"id",
7474
"name",
7575
}
76-
db_service.session_factory.close()
76+
# db_service.session_factory.close()
7777

7878
def test_model_export_exclude(self, db_service, ignore_base):
7979
user_factory = get_model_factory(db_service)
@@ -83,7 +83,7 @@ def test_model_export_exclude(self, db_service, ignore_base):
8383
user = user_factory()
8484

8585
assert user.dict(exclude={"email", "name"}).keys() == {"address", "city", "id"}
86-
db_service.session_factory.close()
86+
# db_service.session_factory.close()
8787

8888

8989
@pytest.mark.asyncio
@@ -105,7 +105,6 @@ async def test_model_export_without_filter_async(
105105
"id": 1,
106106
"name": "Ellar",
107107
}
108-
await db_service_async.session_factory.close()
109108

110109
async def test_model_exclude_none_async(self, db_service_async, ignore_base):
111110
user_factory = get_model_factory(db_service_async)
@@ -121,7 +120,6 @@ async def test_model_exclude_none_async(self, db_service_async, ignore_base):
121120
"id": 1,
122121
"name": "Ellar",
123122
}
124-
await db_service_async.session_factory.close()
125123

126124
async def test_model_export_include_async(self, db_service_async, ignore_base):
127125
user_factory = get_model_factory(db_service_async)
@@ -135,7 +133,6 @@ async def test_model_export_include_async(self, db_service_async, ignore_base):
135133
"id",
136134
"name",
137135
}
138-
await db_service_async.session_factory.close()
139136

140137
async def test_model_export_exclude_async(self, db_service_async, ignore_base):
141138
user_factory = get_model_factory(db_service_async)
@@ -145,4 +142,3 @@ async def test_model_export_exclude_async(self, db_service_async, ignore_base):
145142
user = user_factory()
146143

147144
assert user.dict(exclude={"email", "name"}).keys() == {"address", "city", "id"}
148-
await db_service_async.session_factory.close()

0 commit comments

Comments
 (0)