11from __future__ import annotations
22
33from collections .abc import Sequence
4- from typing import Literal , ClassVar
4+ from typing import Literal , ClassVar , Any , TypeVar , cast
55
66from launart import Launart , Service
77from loguru import logger
8- from sqlalchemy .engine .result import Result
9- from sqlalchemy .engine .url import URL
10- from sqlalchemy .ext .asyncio import AsyncSession , async_sessionmaker
8+ from sqlalchemy import Table
9+
10+ from sqlalchemy .ext .asyncio import AsyncSession , async_sessionmaker , create_async_engine
11+ from sqlalchemy .ext .asyncio .engine import AsyncEngine
1112from sqlalchemy .sql .base import Executable
1213from sqlalchemy .sql .selectable import TypedReturnsRows
14+ from sqlalchemy .engine .result import Result
15+ from sqlalchemy .engine .url import URL
1316from sqlalchemy .orm import DeclarativeBase
1417
15- from .manager import DatabaseManager , T_Row
1618from .model import Base
1719from .types import EngineOptions
20+ from .utils import get_subclasses
21+
22+ T_Row = TypeVar ("T_Row" , bound = DeclarativeBase )
1823
1924
2025class SqlalchemyService (Service ):
2126 id : str = "database/sqlalchemy"
22- db : DatabaseManager
23- get_session : async_sessionmaker [AsyncSession ]
2427 base_class : ClassVar [type [DeclarativeBase ]] = Base
28+ engines : dict [str , AsyncEngine ]
29+ session_factory : async_sessionmaker [AsyncSession ]
2530
2631 def __init__ (
2732 self ,
2833 url : str | URL ,
2934 engine_options : EngineOptions | None = None ,
35+ session_options : dict [str , Any ] | None = None ,
36+ binds : dict [str , str | URL ] | None = None ,
3037 create_table_at : Literal ["preparing" , "prepared" , "blocking" ] = "preparing"
3138 ) -> None :
32- self .db = DatabaseManager (url , engine_options )
39+ if engine_options is None :
40+ engine_options = {"echo" : "debug" , "pool_pre_ping" : True }
41+ self .engines ["" ] = create_async_engine (url , ** engine_options )
42+ for key , bind_url in (binds or {}).items ():
43+ self .engines [key ] = create_async_engine (bind_url , ** engine_options )
3344 self .create_table_at = create_table_at
45+ self .session_options = session_options or {"expire_on_commit" : False }
3446 super ().__init__ ()
3547
3648 @property
@@ -41,51 +53,96 @@ def required(self) -> set[str]:
4153 def stages (self ) -> set [Literal ["preparing" , "blocking" , "cleanup" ]]:
4254 return {"preparing" , "blocking" , "cleanup" }
4355
56+ async def initialize (self ):
57+ binds = {}
58+
59+ for model in set (get_subclasses (self .base_class )):
60+ table : Table | None = getattr (model , "__table__" , None )
61+
62+ if table is None or (bind_key := table .info .get ("bind_key" )) is None :
63+ continue
64+
65+ binds [model ] = self .engines .get (bind_key , self .engines ["" ])
66+
67+ self .session_factory = async_sessionmaker (self .engines ["" ], binds = binds , ** self .session_options )
68+ return binds
69+
70+ def get_session (self , ** local_kw ):
71+ return self .session_factory (** local_kw )
72+
4473 async def launch (self , manager : Launart ):
74+ binds : dict [type [Base ], AsyncEngine ] = {}
75+
4576 async with self .stage ("preparing" ):
4677 logger .info ("Initializing database..." )
47- await self .db .initialize ()
48- self .get_session = self .db .session_factory
49- logger .success ("Database initialized!" )
5078 if self .create_table_at == "preparing" :
51- async with self .db .engine .begin () as conn :
52- await conn .run_sync (self .base_class .metadata .create_all )
53- logger .success ("Database tables created!" )
79+ binds = await self .initialize ()
80+ logger .success ("Database initialized!" )
81+ for model , engine in binds .items ():
82+ async with engine .begin () as conn :
83+ await conn .run_sync (model .__table__ .create , checkfirst = True )
84+ logger .success ("Database tables created!" )
5485
86+ if self .create_table_at != "preparing" :
87+ binds = await self .initialize ()
88+ logger .success ("Database initialized!" )
5589 if self .create_table_at == "prepared" :
56- async with self .db .engine .begin () as conn :
57- await conn .run_sync (self .base_class .metadata .create_all )
58- logger .success ("Database tables created!" )
90+ for model , engine in binds .items ():
91+ async with engine .begin () as conn :
92+ await conn .run_sync (model .__table__ .create , checkfirst = True )
93+ logger .success ("Database tables created!" )
5994
6095 async with self .stage ("blocking" ):
6196 if self .create_table_at == "blocking" :
62- async with self .db .engine .begin () as conn :
63- await conn .run_sync (self .base_class .metadata .create_all )
64- logger .success ("Database tables created!" )
97+ for model , engine in binds .items ():
98+ async with engine .begin () as conn :
99+ await conn .run_sync (model .__table__ .create , checkfirst = True )
100+ logger .success ("Database tables created!" )
65101 await manager .status .wait_for_sigexit ()
66102 async with self .stage ("cleanup" ):
67- await self .db .stop ()
103+ for engine in self .engines .values ():
104+ await engine .dispose (close = True )
68105
69106 async def execute (self , sql : Executable ) -> Result :
70- return await self .db .execute (sql )
107+ """执行 SQL 命令"""
108+ async with self .get_session () as session :
109+ return await session .execute (sql )
71110
72111 async def select_all (self , sql : TypedReturnsRows [tuple [T_Row ]]) -> Sequence [T_Row ]:
73- return await self .db .select_all (sql )
112+ async with self .get_session () as session :
113+ result = await session .scalars (sql )
114+ return result .all ()
74115
75116 async def select_first (self , sql : TypedReturnsRows [tuple [T_Row ]]) -> T_Row | None :
76- return await self .db .select_first (sql )
117+ async with self .get_session () as session :
118+ result = await session .scalars (sql )
119+ return cast ("T_Row | None" , result .first ())
77120
78- async def add (self , row : Base ):
79- return await self .db .add (row )
121+ async def add (self , row : Base ) -> None :
122+ async with self .get_session () as session :
123+ session .add (row )
124+ await session .commit ()
125+ await session .refresh (row )
80126
81127 async def add_many (self , rows : Sequence [Base ]):
82- return await self .db .add_many (rows )
128+ async with self .get_session () as session :
129+ session .add_all (rows )
130+ await session .commit ()
131+ for row in rows :
132+ await session .refresh (row )
83133
84134 async def update_or_add (self , row : Base ):
85- return await self .db .update_or_add (row )
135+ async with self .get_session () as session :
136+ await session .merge (row )
137+ await session .commit ()
138+ await session .refresh (row )
86139
87140 async def delete_exist (self , row : Base ):
88- return await self .db .delete_exist (row )
141+ async with self .get_session () as session :
142+ await session .delete (row )
89143
90144 async def delete_many_exist (self , rows : Sequence [Base ]):
91- return await self .db .delete_many_exist (rows )
145+ async with self .get_session () as session :
146+ for row in rows :
147+ await session .delete (row )
148+
0 commit comments