Skip to content

Commit cc94039

Browse files
committed
feat: support multi engine
1 parent d87a8f3 commit cc94039

File tree

5 files changed

+149
-128
lines changed

5 files changed

+149
-128
lines changed
Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,38 @@
11
from __future__ import annotations
22

3+
import logging
4+
5+
from loguru import logger
6+
37
try:
4-
import sqlalchemy as sa
8+
from sqlalchemy.log import Identified, _qual_logger_name_for_cls
59
except ImportError:
610
raise ImportError(
711
"dependency 'sqlalchemy' is required for sqlalchemy service\nplease install it or install 'graia-amnesia[sqla]'"
812
)
913

1014
from .model import Base as Base
1115
from .service import SqlalchemyService as SqlalchemyService
16+
from .utils import LoguruHandler, get_subclasses
17+
18+
19+
def pacth_logger(log_level: str | int = "INFO", sqlalchemy_echo: bool = False) -> None:
20+
handler = LoguruHandler()
21+
logging.getLogger("sqlalchemy").addHandler(handler)
22+
23+
if isinstance(log_level, str):
24+
log_level = logger.level(log_level).no
25+
26+
echo_log_level = log_level if sqlalchemy_echo else logging.WARNING
27+
28+
levels = {
29+
"alembic": log_level,
30+
"sqlalchemy": log_level,
31+
**{
32+
_qual_logger_name_for_cls(cls): echo_log_level
33+
for cls in set(get_subclasses(Identified))
34+
},
35+
}
36+
37+
for name, level in levels.items():
38+
logging.getLogger(name).setLevel(level)

src/graia/amnesia/builtins/sqla/manager.py

Lines changed: 0 additions & 95 deletions
This file was deleted.

src/graia/amnesia/builtins/sqla/model.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Callable, cast
1+
from typing import Any, Callable, TYPE_CHECKING, ClassVar
22

33
from sqlalchemy import MetaData
44
from sqlalchemy.ext.asyncio import AsyncAttrs
@@ -20,7 +20,7 @@ def _setup_bind(cls: type["Base"]) -> None:
2020
if bind_key is None:
2121
bind_key = ""
2222

23-
cast(Table, cls.__table__).info["bind_key"] = bind_key
23+
cls.__table__.info["bind_key"] = bind_key
2424

2525

2626
_callbacks = []
@@ -46,6 +46,9 @@ class Base(AsyncAttrs, DeclarativeBase):
4646
__abstract__ = True
4747
metadata = MetaData(naming_convention=_NAMING_CONVENTION)
4848

49+
if TYPE_CHECKING:
50+
__table__: ClassVar[Table] # type: ignore
51+
4952
def __init_subclass__(cls, **kwargs):
5053
for callback in _callbacks:
5154
callback(cls)
Lines changed: 87 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,48 @@
11
from __future__ import annotations
22

33
from collections.abc import Sequence
4-
from typing import Literal, ClassVar
4+
from typing import Literal, ClassVar, Any, TypeVar, cast
55

66
from launart import Launart, Service
77
from loguru import logger
8-
from sqlalchemy.engine.result import Result
9-
from sqlalchemy.engine.url import URL
10-
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
8+
from sqlalchemy import Table
9+
10+
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
11+
from sqlalchemy.ext.asyncio.engine import AsyncEngine
1112
from sqlalchemy.sql.base import Executable
1213
from sqlalchemy.sql.selectable import TypedReturnsRows
14+
from sqlalchemy.engine.result import Result
15+
from sqlalchemy.engine.url import URL
1316
from sqlalchemy.orm import DeclarativeBase
1417

15-
from .manager import DatabaseManager, T_Row
1618
from .model import Base
1719
from .types import EngineOptions
20+
from .utils import get_subclasses
21+
22+
T_Row = TypeVar("T_Row", bound=DeclarativeBase)
1823

1924

2025
class SqlalchemyService(Service):
2126
id: str = "database/sqlalchemy"
22-
db: DatabaseManager
23-
get_session: async_sessionmaker[AsyncSession]
2427
base_class: ClassVar[type[DeclarativeBase]] = Base
28+
engines: dict[str, AsyncEngine]
29+
session_factory: async_sessionmaker[AsyncSession]
2530

2631
def __init__(
2732
self,
2833
url: str | URL,
2934
engine_options: EngineOptions | None = None,
35+
session_options: dict[str, Any] | None = None,
36+
binds: dict[str, str | URL] | None = None,
3037
create_table_at: Literal["preparing", "prepared", "blocking"] = "preparing"
3138
) -> None:
32-
self.db = DatabaseManager(url, engine_options)
39+
if engine_options is None:
40+
engine_options = {"echo": "debug", "pool_pre_ping": True}
41+
self.engines[""] = create_async_engine(url, **engine_options)
42+
for key, bind_url in (binds or {}).items():
43+
self.engines[key] = create_async_engine(bind_url, **engine_options)
3344
self.create_table_at = create_table_at
45+
self.session_options = session_options or {"expire_on_commit": False}
3446
super().__init__()
3547

3648
@property
@@ -41,51 +53,96 @@ def required(self) -> set[str]:
4153
def stages(self) -> set[Literal["preparing", "blocking", "cleanup"]]:
4254
return {"preparing", "blocking", "cleanup"}
4355

56+
async def initialize(self):
57+
binds = {}
58+
59+
for model in set(get_subclasses(self.base_class)):
60+
table: Table | None = getattr(model, "__table__", None)
61+
62+
if table is None or (bind_key := table.info.get("bind_key")) is None:
63+
continue
64+
65+
binds[model] = self.engines.get(bind_key, self.engines[""])
66+
67+
self.session_factory = async_sessionmaker(self.engines[""], binds=binds, **self.session_options)
68+
return binds
69+
70+
def get_session(self, **local_kw):
71+
return self.session_factory(**local_kw)
72+
4473
async def launch(self, manager: Launart):
74+
binds: dict[type[Base], AsyncEngine] = {}
75+
4576
async with self.stage("preparing"):
4677
logger.info("Initializing database...")
47-
await self.db.initialize()
48-
self.get_session = self.db.session_factory
49-
logger.success("Database initialized!")
5078
if self.create_table_at == "preparing":
51-
async with self.db.engine.begin() as conn:
52-
await conn.run_sync(self.base_class.metadata.create_all)
53-
logger.success("Database tables created!")
79+
binds = await self.initialize()
80+
logger.success("Database initialized!")
81+
for model, engine in binds.items():
82+
async with engine.begin() as conn:
83+
await conn.run_sync(model.__table__.create, checkfirst=True)
84+
logger.success("Database tables created!")
5485

86+
if self.create_table_at != "preparing":
87+
binds = await self.initialize()
88+
logger.success("Database initialized!")
5589
if self.create_table_at == "prepared":
56-
async with self.db.engine.begin() as conn:
57-
await conn.run_sync(self.base_class.metadata.create_all)
58-
logger.success("Database tables created!")
90+
for model, engine in binds.items():
91+
async with engine.begin() as conn:
92+
await conn.run_sync(model.__table__.create, checkfirst=True)
93+
logger.success("Database tables created!")
5994

6095
async with self.stage("blocking"):
6196
if self.create_table_at == "blocking":
62-
async with self.db.engine.begin() as conn:
63-
await conn.run_sync(self.base_class.metadata.create_all)
64-
logger.success("Database tables created!")
97+
for model, engine in binds.items():
98+
async with engine.begin() as conn:
99+
await conn.run_sync(model.__table__.create, checkfirst=True)
100+
logger.success("Database tables created!")
65101
await manager.status.wait_for_sigexit()
66102
async with self.stage("cleanup"):
67-
await self.db.stop()
103+
for engine in self.engines.values():
104+
await engine.dispose(close=True)
68105

69106
async def execute(self, sql: Executable) -> Result:
70-
return await self.db.execute(sql)
107+
"""执行 SQL 命令"""
108+
async with self.get_session() as session:
109+
return await session.execute(sql)
71110

72111
async def select_all(self, sql: TypedReturnsRows[tuple[T_Row]]) -> Sequence[T_Row]:
73-
return await self.db.select_all(sql)
112+
async with self.get_session() as session:
113+
result = await session.scalars(sql)
114+
return result.all()
74115

75116
async def select_first(self, sql: TypedReturnsRows[tuple[T_Row]]) -> T_Row | None:
76-
return await self.db.select_first(sql)
117+
async with self.get_session() as session:
118+
result = await session.scalars(sql)
119+
return cast("T_Row | None", result.first())
77120

78-
async def add(self, row: Base):
79-
return await self.db.add(row)
121+
async def add(self, row: Base) -> None:
122+
async with self.get_session() as session:
123+
session.add(row)
124+
await session.commit()
125+
await session.refresh(row)
80126

81127
async def add_many(self, rows: Sequence[Base]):
82-
return await self.db.add_many(rows)
128+
async with self.get_session() as session:
129+
session.add_all(rows)
130+
await session.commit()
131+
for row in rows:
132+
await session.refresh(row)
83133

84134
async def update_or_add(self, row: Base):
85-
return await self.db.update_or_add(row)
135+
async with self.get_session() as session:
136+
await session.merge(row)
137+
await session.commit()
138+
await session.refresh(row)
86139

87140
async def delete_exist(self, row: Base):
88-
return await self.db.delete_exist(row)
141+
async with self.get_session() as session:
142+
await session.delete(row)
89143

90144
async def delete_many_exist(self, rows: Sequence[Base]):
91-
return await self.db.delete_many_exist(rows)
145+
async with self.get_session() as session:
146+
for row in rows:
147+
await session.delete(row)
148+
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import logging
2+
import sys
3+
4+
from loguru import logger
5+
6+
7+
class LoguruHandler(logging.Handler):
8+
def emit(self, record: logging.LogRecord):
9+
try:
10+
level = logger.level(record.levelname).name
11+
if record.levelno <= logging.INFO:
12+
level = {"DEBUG": "TRACE", "INFO": "DEBUG"}.get(level, level)
13+
except ValueError:
14+
level = record.levelno
15+
16+
frame, depth = sys._getframe(6), 6
17+
while frame and frame.f_code.co_filename == logging.__file__:
18+
frame = frame.f_back
19+
depth += 1
20+
21+
logger.opt(depth=depth, exception=record.exc_info).log(
22+
level, record.getMessage()
23+
)
24+
25+
26+
def get_subclasses(cls):
27+
yield from cls.__subclasses__()
28+
for subclass in cls.__subclasses__():
29+
yield from get_subclasses(subclass)

0 commit comments

Comments
 (0)