Skip to content

Commit 3c962a2

Browse files
committed
chore: format
1 parent cc94039 commit 3c962a2

File tree

8 files changed

+69
-57
lines changed

8 files changed

+69
-57
lines changed

pdm.lock

Lines changed: 30 additions & 18 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,12 @@ includes = ["src/graia"]
5353

5454
[tool.pdm.dev-dependencies]
5555
dev = [
56-
"black<23.0.0,>=22.1.0",
56+
"black>=25.0.0",
5757
"uvicorn>=0.23.2",
5858
"aiohttp>=3.9.1",
5959
"httpx>=0.26.0",
6060
"sqlalchemy>=2.0.25",
61-
"isort>=5.13.2",
61+
"isort==5.13.2",
6262
"pytest>=7.4.4",
6363
]
6464

src/graia/amnesia/builtins/asgi/asgitypes.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -238,11 +238,9 @@ class LifespanShutdownFailedEvent(TypedDict):
238238

239239

240240
class ASGI2Protocol(Protocol):
241-
def __init__(self, scope: Scope) -> None:
242-
...
241+
def __init__(self, scope: Scope) -> None: ...
243242

244-
async def __call__(self, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None:
245-
...
243+
async def __call__(self, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None: ...
246244

247245

248246
ASGI2Application = Type[ASGI2Protocol]

src/graia/amnesia/builtins/sqla/__init__.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,7 @@ def pacth_logger(log_level: str | int = "INFO", sqlalchemy_echo: bool = False) -
2828
levels = {
2929
"alembic": log_level,
3030
"sqlalchemy": log_level,
31-
**{
32-
_qual_logger_name_for_cls(cls): echo_log_level
33-
for cls in set(get_subclasses(Identified))
34-
},
31+
**{_qual_logger_name_for_cls(cls): echo_log_level for cls in set(get_subclasses(Identified))},
3532
}
3633

3734
for name, level in levels.items():

src/graia/amnesia/builtins/sqla/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Callable, TYPE_CHECKING, ClassVar
1+
from typing import TYPE_CHECKING, Any, Callable, ClassVar
22

33
from sqlalchemy import MetaData
44
from sqlalchemy.ext.asyncio import AsyncAttrs

src/graia/amnesia/builtins/sqla/service.py

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,18 @@
11
from __future__ import annotations
22

33
from collections.abc import Sequence
4-
from typing import Literal, ClassVar, Any, TypeVar, cast
4+
from typing import Any, ClassVar, Literal, TypeVar, cast
55

66
from launart import Launart, Service
77
from loguru import logger
88
from sqlalchemy import Table
9-
9+
from sqlalchemy.engine.result import Result
10+
from sqlalchemy.engine.url import URL
1011
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
1112
from sqlalchemy.ext.asyncio.engine import AsyncEngine
13+
from sqlalchemy.orm import DeclarativeBase
1214
from sqlalchemy.sql.base import Executable
1315
from sqlalchemy.sql.selectable import TypedReturnsRows
14-
from sqlalchemy.engine.result import Result
15-
from sqlalchemy.engine.url import URL
16-
from sqlalchemy.orm import DeclarativeBase
1716

1817
from .model import Base
1918
from .types import EngineOptions
@@ -34,7 +33,7 @@ def __init__(
3433
engine_options: EngineOptions | None = None,
3534
session_options: dict[str, Any] | None = None,
3635
binds: dict[str, str | URL] | None = None,
37-
create_table_at: Literal["preparing", "prepared", "blocking"] = "preparing"
36+
create_table_at: Literal["preparing", "prepared", "blocking"] = "preparing",
3837
) -> None:
3938
if engine_options is None:
4039
engine_options = {"echo": "debug", "pool_pre_ping": True}
@@ -54,6 +53,7 @@ def stages(self) -> set[Literal["preparing", "blocking", "cleanup"]]:
5453
return {"preparing", "blocking", "cleanup"}
5554

5655
async def initialize(self):
56+
_binds = {}
5757
binds = {}
5858

5959
for model in set(get_subclasses(self.base_class)):
@@ -62,41 +62,52 @@ async def initialize(self):
6262
if table is None or (bind_key := table.info.get("bind_key")) is None:
6363
continue
6464

65-
binds[model] = self.engines.get(bind_key, self.engines[""])
65+
if bind_key in self.engines:
66+
_binds[model] = self.engines[bind_key]
67+
binds.setdefault(bind_key, []).append(model)
68+
else:
69+
_binds[model] = self.engines[""]
70+
binds.setdefault("", []).append(model)
6671

67-
self.session_factory = async_sessionmaker(self.engines[""], binds=binds, **self.session_options)
72+
self.session_factory = async_sessionmaker(self.engines[""], binds=_binds, **self.session_options)
6873
return binds
6974

7075
def get_session(self, **local_kw):
7176
return self.session_factory(**local_kw)
7277

7378
async def launch(self, manager: Launart):
74-
binds: dict[type[Base], AsyncEngine] = {}
79+
binds: dict[str, list[type[Base]]] = {}
7580

7681
async with self.stage("preparing"):
7782
logger.info("Initializing database...")
7883
if self.create_table_at == "preparing":
7984
binds = await self.initialize()
8085
logger.success("Database initialized!")
81-
for model, engine in binds.items():
82-
async with engine.begin() as conn:
83-
await conn.run_sync(model.__table__.create, checkfirst=True)
86+
for key, models in binds.items():
87+
async with self.engines[key].begin() as conn:
88+
await conn.run_sync(
89+
self.base_class.metadata.create_all, tables=[m.__table__ for m in models], checkfirst=True
90+
)
8491
logger.success("Database tables created!")
8592

8693
if self.create_table_at != "preparing":
8794
binds = await self.initialize()
8895
logger.success("Database initialized!")
8996
if self.create_table_at == "prepared":
90-
for model, engine in binds.items():
91-
async with engine.begin() as conn:
92-
await conn.run_sync(model.__table__.create, checkfirst=True)
97+
for key, models in binds.items():
98+
async with self.engines[key].begin() as conn:
99+
await conn.run_sync(
100+
self.base_class.metadata.create_all, tables=[m.__table__ for m in models], checkfirst=True
101+
)
93102
logger.success("Database tables created!")
94103

95104
async with self.stage("blocking"):
96105
if self.create_table_at == "blocking":
97-
for model, engine in binds.items():
98-
async with engine.begin() as conn:
99-
await conn.run_sync(model.__table__.create, checkfirst=True)
106+
for key, models in binds.items():
107+
async with self.engines[key].begin() as conn:
108+
await conn.run_sync(
109+
self.base_class.metadata.create_all, tables=[m.__table__ for m in models], checkfirst=True
110+
)
100111
logger.success("Database tables created!")
101112
await manager.status.wait_for_sigexit()
102113
async with self.stage("cleanup"):
@@ -145,4 +156,3 @@ async def delete_many_exist(self, rows: Sequence[Base]):
145156
async with self.get_session() as session:
146157
for row in rows:
147158
await session.delete(row)
148-

src/graia/amnesia/builtins/sqla/utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,7 @@ def emit(self, record: logging.LogRecord):
1818
frame = frame.f_back
1919
depth += 1
2020

21-
logger.opt(depth=depth, exception=record.exc_info).log(
22-
level, record.getMessage()
23-
)
21+
logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage())
2422

2523

2624
def get_subclasses(cls):

src/graia/amnesia/message/chain.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -174,16 +174,13 @@ def join(self, *chains: Self | Iterable[Self]) -> Self:
174174
__contains__ = has
175175

176176
@overload
177-
def __getitem__(self, item: type[E]) -> list[E]:
178-
...
177+
def __getitem__(self, item: type[E]) -> list[E]: ...
179178

180179
@overload
181-
def __getitem__(self, item: int) -> Element:
182-
...
180+
def __getitem__(self, item: int) -> Element: ...
183181

184182
@overload
185-
def __getitem__(self, item: slice) -> Self:
186-
...
183+
def __getitem__(self, item: slice) -> Self: ...
187184

188185
def __getitem__(self, item: type[Element] | int | slice) -> Any:
189186
"""取出子消息链, 或元素.

0 commit comments

Comments
 (0)