Skip to content

Commit d87a8f3

Browse files
committed
feat: allow callbacks when SQLA modelbase init subclass
1 parent 24adee6 commit d87a8f3

File tree

7 files changed

+138
-61
lines changed

7 files changed

+138
-61
lines changed

pdm.lock

Lines changed: 69 additions & 50 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/graia/amnesia/builtins/aiohttp.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
from typing import cast
44

55
from launart import Launart, Service
6+
from launart.status import Phase
67

78
try:
89
from aiohttp import ClientSession, ClientTimeout
910
except ImportError:
1011
raise ImportError(
11-
"dependency 'aiohttp' is required for aiohttp client service\nplease install it or install 'graia-amnesia[aiohttp]'"
12+
"dependency 'aiohttp' is required for aiohttp client service\n"
13+
"please install it or install 'graia-amnesia[aiohttp]'"
1214
)
1315

1416

@@ -21,14 +23,14 @@ def __init__(self, session: ClientSession | None = None) -> None:
2123
super().__init__()
2224

2325
@property
24-
def stages(self):
26+
def stages(self) -> set[Phase]:
2527
return {"preparing", "cleanup"}
2628

2729
@property
2830
def required(self):
2931
return set()
3032

31-
async def launch(self, _: Launart):
33+
async def launch(self, manager: Launart):
3234
async with self.stage("preparing"):
3335
if self.session is None:
3436
self.session = ClientSession(timeout=ClientTimeout(total=None))

src/graia/amnesia/builtins/asgi/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import logging
55

66
from launart import Launart, Service
7+
from launart.status import Phase
78
from launart.utilles import any_completed
89
from loguru import logger
910

@@ -85,7 +86,7 @@ def required(self):
8586
return set()
8687

8788
@property
88-
def stages(self):
89+
def stages(self) -> set[Phase]:
8990
return {"preparing", "blocking", "cleanup"}
9091

9192
async def launch(self, manager: Launart) -> None:

src/graia/amnesia/builtins/httpx.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import cast
44

55
from launart import Launart, Service
6+
from launart.status import Phase
67

78
try:
89
from httpx import AsyncClient, Timeout
@@ -21,14 +22,14 @@ def __init__(self, session: AsyncClient | None = None) -> None:
2122
super().__init__()
2223

2324
@property
24-
def stages(self):
25+
def stages(self) -> set[Phase]:
2526
return {"preparing", "cleanup"}
2627

2728
@property
2829
def required(self):
2930
return set()
3031

31-
async def launch(self, _: Launart):
32+
async def launch(self, manager: Launart):
3233
async with self.stage("preparing"):
3334
if self.session is None:
3435
self.session = AsyncClient(timeout=Timeout())

src/graia/amnesia/builtins/memcache.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Any
88

99
from launart import Launart, Service
10+
from launart.status import Phase
1011

1112

1213
class Memcache:
@@ -75,7 +76,7 @@ def required(self):
7576
return set()
7677

7778
@property
78-
def stages(self):
79+
def stages(self) -> set[Phase]:
7980
return {"blocking"}
8081

8182
@property

src/graia/amnesia/builtins/sqla/model.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
from typing import Any, Callable, cast
2+
13
from sqlalchemy import MetaData
24
from sqlalchemy.ext.asyncio import AsyncAttrs
35
from sqlalchemy.orm import DeclarativeBase
6+
from sqlalchemy.schema import Table
47

58
_NAMING_CONVENTION = {
69
"ix": "ix_%(column_0_label)s",
@@ -11,6 +14,54 @@
1114
}
1215

1316

17+
def _setup_bind(cls: type["Base"]) -> None:
18+
bind_key: str | None = getattr(cls, "__bind_key__", None)
19+
20+
if bind_key is None:
21+
bind_key = ""
22+
23+
cast(Table, cls.__table__).info["bind_key"] = bind_key
24+
25+
26+
_callbacks = []
27+
28+
29+
def register_callback(callback: Callable[[type["Base"]], Any]) -> None:
30+
"""
31+
Register a callback to be called when a new Base subclass is created.
32+
The callback should accept a single argument, which is the subclass itself.
33+
"""
34+
_callbacks.append(callback)
35+
36+
37+
def remove_callback(callback: Callable[[type["Base"]], Any]) -> None:
38+
"""
39+
Remove a previously registered callback.
40+
"""
41+
if callback in _callbacks:
42+
_callbacks.remove(callback)
43+
44+
1445
class Base(AsyncAttrs, DeclarativeBase):
1546
__abstract__ = True
1647
metadata = MetaData(naming_convention=_NAMING_CONVENTION)
48+
49+
def __init_subclass__(cls, **kwargs):
50+
for callback in _callbacks:
51+
callback(cls)
52+
53+
if not hasattr(cls, "__tablename__") and "tablename" in kwargs:
54+
cls.__tablename__ = kwargs["tablename"]
55+
if not hasattr(cls, "__table_args__") and "table_args" in kwargs:
56+
cls.__table_args__ = kwargs["table_args"]
57+
if not hasattr(cls, "__mapper__") and "mapper" in kwargs:
58+
cls.__mapper__ = kwargs["mapper"]
59+
if not hasattr(cls, "__mapper_args__") and "mapper_args" in kwargs:
60+
cls.__mapping_args__ = kwargs["mapper_args"]
61+
62+
super().__init_subclass__(**kwargs)
63+
64+
if not hasattr(cls, "__table__"):
65+
return
66+
67+
_setup_bind(cls)

src/graia/amnesia/builtins/sqla/service.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from collections.abc import Sequence
4-
from typing import Literal
4+
from typing import Literal, ClassVar
55

66
from launart import Launart, Service
77
from loguru import logger
@@ -10,6 +10,7 @@
1010
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
1111
from sqlalchemy.sql.base import Executable
1212
from sqlalchemy.sql.selectable import TypedReturnsRows
13+
from sqlalchemy.orm import DeclarativeBase
1314

1415
from .manager import DatabaseManager, T_Row
1516
from .model import Base
@@ -20,6 +21,7 @@ class SqlalchemyService(Service):
2021
id: str = "database/sqlalchemy"
2122
db: DatabaseManager
2223
get_session: async_sessionmaker[AsyncSession]
24+
base_class: ClassVar[type[DeclarativeBase]] = Base
2325

2426
def __init__(
2527
self,
@@ -47,18 +49,18 @@ async def launch(self, manager: Launart):
4749
logger.success("Database initialized!")
4850
if self.create_table_at == "preparing":
4951
async with self.db.engine.begin() as conn:
50-
await conn.run_sync(Base.metadata.create_all)
52+
await conn.run_sync(self.base_class.metadata.create_all)
5153
logger.success("Database tables created!")
5254

5355
if self.create_table_at == "prepared":
5456
async with self.db.engine.begin() as conn:
55-
await conn.run_sync(Base.metadata.create_all)
57+
await conn.run_sync(self.base_class.metadata.create_all)
5658
logger.success("Database tables created!")
5759

5860
async with self.stage("blocking"):
5961
if self.create_table_at == "blocking":
6062
async with self.db.engine.begin() as conn:
61-
await conn.run_sync(Base.metadata.create_all)
63+
await conn.run_sync(self.base_class.metadata.create_all)
6264
logger.success("Database tables created!")
6365
await manager.status.wait_for_sigexit()
6466
async with self.stage("cleanup"):

0 commit comments

Comments
 (0)