11from __future__ import annotations
22
33from collections .abc import Sequence
4- from typing import Literal , ClassVar , Any , TypeVar , cast
4+ from typing import Any , ClassVar , Literal , TypeVar , cast
55
66from launart import Launart , Service
77from loguru import logger
88from sqlalchemy import Table
9-
9+ from sqlalchemy .engine .result import Result
10+ from sqlalchemy .engine .url import URL
1011from sqlalchemy .ext .asyncio import AsyncSession , async_sessionmaker , create_async_engine
1112from sqlalchemy .ext .asyncio .engine import AsyncEngine
13+ from sqlalchemy .orm import DeclarativeBase
1214from sqlalchemy .sql .base import Executable
1315from sqlalchemy .sql .selectable import TypedReturnsRows
14- from sqlalchemy .engine .result import Result
15- from sqlalchemy .engine .url import URL
16- from sqlalchemy .orm import DeclarativeBase
1716
1817from .model import Base
1918from .types import EngineOptions
@@ -34,7 +33,7 @@ def __init__(
3433 engine_options : EngineOptions | None = None ,
3534 session_options : dict [str , Any ] | None = None ,
3635 binds : dict [str , str | URL ] | None = None ,
37- create_table_at : Literal ["preparing" , "prepared" , "blocking" ] = "preparing"
36+ create_table_at : Literal ["preparing" , "prepared" , "blocking" ] = "preparing" ,
3837 ) -> None :
3938 if engine_options is None :
4039 engine_options = {"echo" : "debug" , "pool_pre_ping" : True }
@@ -54,6 +53,7 @@ def stages(self) -> set[Literal["preparing", "blocking", "cleanup"]]:
5453 return {"preparing" , "blocking" , "cleanup" }
5554
5655 async def initialize (self ):
56+ _binds = {}
5757 binds = {}
5858
5959 for model in set (get_subclasses (self .base_class )):
@@ -62,41 +62,52 @@ async def initialize(self):
6262 if table is None or (bind_key := table .info .get ("bind_key" )) is None :
6363 continue
6464
65- binds [model ] = self .engines .get (bind_key , self .engines ["" ])
65+ if bind_key in self .engines :
66+ _binds [model ] = self .engines [bind_key ]
67+ binds .setdefault (bind_key , []).append (model )
68+ else :
69+ _binds [model ] = self .engines ["" ]
70+ binds .setdefault ("" , []).append (model )
6671
67- self .session_factory = async_sessionmaker (self .engines ["" ], binds = binds , ** self .session_options )
72+ self .session_factory = async_sessionmaker (self .engines ["" ], binds = _binds , ** self .session_options )
6873 return binds
6974
7075 def get_session (self , ** local_kw ):
7176 return self .session_factory (** local_kw )
7277
7378 async def launch (self , manager : Launart ):
74- binds : dict [type [Base ], AsyncEngine ] = {}
79+ binds : dict [str , list [ type [Base ]] ] = {}
7580
7681 async with self .stage ("preparing" ):
7782 logger .info ("Initializing database..." )
7883 if self .create_table_at == "preparing" :
7984 binds = await self .initialize ()
8085 logger .success ("Database initialized!" )
81- for model , engine in binds .items ():
82- async with engine .begin () as conn :
83- await conn .run_sync (model .__table__ .create , checkfirst = True )
86+ for key , models in binds .items ():
87+ async with self .engines [key ].begin () as conn :
88+ await conn .run_sync (
89+ self .base_class .metadata .create_all , tables = [m .__table__ for m in models ], checkfirst = True
90+ )
8491 logger .success ("Database tables created!" )
8592
8693 if self .create_table_at != "preparing" :
8794 binds = await self .initialize ()
8895 logger .success ("Database initialized!" )
8996 if self .create_table_at == "prepared" :
90- for model , engine in binds .items ():
91- async with engine .begin () as conn :
92- await conn .run_sync (model .__table__ .create , checkfirst = True )
97+ for key , models in binds .items ():
98+ async with self .engines [key ].begin () as conn :
99+ await conn .run_sync (
100+ self .base_class .metadata .create_all , tables = [m .__table__ for m in models ], checkfirst = True
101+ )
93102 logger .success ("Database tables created!" )
94103
95104 async with self .stage ("blocking" ):
96105 if self .create_table_at == "blocking" :
97- for model , engine in binds .items ():
98- async with engine .begin () as conn :
99- await conn .run_sync (model .__table__ .create , checkfirst = True )
106+ for key , models in binds .items ():
107+ async with self .engines [key ].begin () as conn :
108+ await conn .run_sync (
109+ self .base_class .metadata .create_all , tables = [m .__table__ for m in models ], checkfirst = True
110+ )
100111 logger .success ("Database tables created!" )
101112 await manager .status .wait_for_sigexit ()
102113 async with self .stage ("cleanup" ):
@@ -145,4 +156,3 @@ async def delete_many_exist(self, rows: Sequence[Base]):
145156 async with self .get_session () as session :
146157 for row in rows :
147158 await session .delete (row )
148-
0 commit comments