22import typing as t
33
44import sqlalchemy as sa
5- from ellar .common import IExecutionContext , IModuleSetup , Module , middleware
5+ from ellar .common import IHostContext , IModuleSetup , Module
66from ellar .core import Config , DynamicModule , ModuleBase , ModuleSetup
7+ from ellar .core .middleware import as_middleware
8+ from ellar .core .modules import ModuleRefBase
79from ellar .di import ProviderConfig , request_or_transient_scope
8- from ellar .events import app_context_teardown
910from ellar .utils .importer import get_main_directory_by_stack
1011from sqlalchemy .ext .asyncio import (
1112 AsyncEngine ,
@@ -26,32 +27,49 @@ def _raise_exception():
2627 return _raise_exception
2728
2829
29- @Module (commands = [DBCommands ])
30- class EllarSQLModule (ModuleBase , IModuleSetup ):
31- @middleware ()
32- async def session_middleware (
33- cls , context : IExecutionContext , call_next : t .Callable [..., t .Coroutine ]
34- ):
35- connection = context .switch_to_http_connection ().get_client ()
36-
37- db_session = connection .service_provider .get (EllarSQLService )
38- session = db_session .session_factory ()
30+ @as_middleware
31+ async def session_middleware (
32+ context : IHostContext , call_next : t .Callable [..., t .Coroutine ]
33+ ):
34+ connection = context .switch_to_http_connection ().get_client ()
3935
40- connection .state .session = session
36+ db_service = context .get_service_provider ().get (EllarSQLService )
37+ session = db_service .session_factory ()
4138
42- try :
43- await call_next ()
44- except Exception as ex :
45- res = session .rollback ()
46- if isinstance (res , t .Coroutine ):
47- await res
48- raise ex
39+ connection .state .session = session
4940
50- @classmethod
51- async def _on_application_tear_down (cls , db_service : EllarSQLService ) -> None :
52- res = db_service .session_factory .remove ()
41+ try :
42+ await call_next ()
43+ except Exception as ex :
44+ res = session .rollback ()
5345 if isinstance (res , t .Coroutine ):
5446 await res
47+ raise ex
48+
49+ res = db_service .session_factory .remove ()
50+ if isinstance (res , t .Coroutine ):
51+ await res
52+
53+
54+ @Module (
55+ commands = [DBCommands ],
56+ exports = [
57+ EllarSQLService ,
58+ Session ,
59+ AsyncSession ,
60+ AsyncEngine ,
61+ sa .Engine ,
62+ MigrationOption ,
63+ ],
64+ providers = [EllarSQLService ],
65+ name = "EllarSQL" ,
66+ )
67+ class EllarSQLModule (ModuleBase , IModuleSetup ):
68+ @classmethod
69+ def post_build (cls , module_ref : "ModuleRefBase" ) -> None :
70+ module_ref .config .MIDDLEWARE = list (module_ref .config .MIDDLEWARE ) + [
71+ session_middleware
72+ ]
5573
5674 @classmethod
5775 def setup (
@@ -155,8 +173,10 @@ def __setup_module(cls, sql_alchemy_config: SQLAlchemyConfig) -> DynamicModule:
155173 )
156174
157175 providers .append (ProviderConfig (EllarSQLService , use_value = db_service ))
158- app_context_teardown .connect (
159- functools .partial (cls ._on_application_tear_down , db_service = db_service )
176+ providers .append (
177+ ProviderConfig (
178+ MigrationOption , use_value = lambda : db_service .migration_options
179+ )
160180 )
161181
162182 return DynamicModule (
@@ -182,7 +202,7 @@ def register_setup(cls, **override_config: t.Any) -> ModuleSetup:
182202
183203 @staticmethod
184204 def __register_setup_factory (
185- module : t . Type [ "EllarSQLModule" ] ,
205+ module_ref : ModuleRefBase ,
186206 config : Config ,
187207 root_path : str ,
188208 override_config : t .Dict [str , t .Any ],
@@ -201,6 +221,7 @@ def __register_setup_factory(
201221 stack_level = 0 ,
202222 from_dir = defined_config ["root_path" ],
203223 )
224+ module = t .cast (t .Type ["EllarSQLModule" ], module_ref .module )
204225
205226 return module .__setup_module (schema )
206227 raise RuntimeError ("Could not find `ELLAR_SQL` in application config." )
0 commit comments