Skip to content

Commit 7b455d4

Browse files
committed
🎨 refactor(sqla): move param init into __init__.py
1 parent 88b8847 commit 7b455d4

File tree

3 files changed

+35
-44
lines changed

3 files changed

+35
-44
lines changed

nonebot_plugin_orm/__init__.py

Lines changed: 32 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,29 @@
22

33
import sys
44
import logging
5+
from typing import Any
56
from asyncio import gather
67
from operator import methodcaller
7-
from typing import Any, AsyncGenerator
8+
from collections.abc import AsyncGenerator
89
from functools import wraps, partial, lru_cache
910

10-
from nonebot.params import Depends
11+
import click
12+
from nonebot.rule import Rule
1113
import sqlalchemy.ext.asyncio as sa_async
14+
from nonebot.permission import Permission
1215
from sqlalchemy.util import greenlet_spawn
13-
from nonebot.matcher import current_matcher
16+
from nonebot.params import Depends, DefaultParam
1417
from nonebot.plugin import Plugin, PluginMetadata
18+
from nonebot.matcher import Matcher, current_matcher
1519
from sqlalchemy import URL, Table, MetaData, make_url
20+
from sqlalchemy.log import Identified, _qual_logger_name_for_cls
1621
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
1722
from nonebot import logger, require, get_driver, get_plugin_by_module_name
1823

1924
from . import migrate
20-
from .config import Config
21-
from .utils import LoguruHandler, StreamToLogger
25+
from .param import ORMParam
26+
from .config import Config, plugin_config
27+
from .utils import LoguruHandler, StreamToLogger, get_subclasses
2228

2329
if sys.version_info >= (3, 10):
2430
from typing import Annotated
@@ -36,28 +42,11 @@
3642
"Model",
3743
# param
3844
"SQLDepends",
39-
"ORMParam",
4045
# config
4146
"Config",
42-
"config",
47+
"plugin_config",
4348
# migrate
4449
"AlembicConfig",
45-
"list_templates",
46-
"init",
47-
"revision",
48-
"check",
49-
"merge",
50-
"upgrade",
51-
"downgrade",
52-
"sync",
53-
"show",
54-
"history",
55-
"heads",
56-
"branches",
57-
"current",
58-
"stamp",
59-
"edit",
60-
"ensure_version",
6150
)
6251
__plugin_meta__ = PluginMetadata(
6352
name="nonebot-plugin-orm",
@@ -68,7 +57,6 @@
6857
config=Config,
6958
)
7059

71-
7260
_binds: dict[type[Model], AsyncEngine]
7361
_engines: dict[str, AsyncEngine]
7462
_metadatas: dict[str, MetaData]
@@ -230,9 +218,24 @@ def _init_logger():
230218
l.setLevel(level)
231219

232220

233-
from .model import *
234-
from .param import *
235-
from .config import *
236-
from .migrate import *
237-
238221
_init_logger()
222+
223+
224+
def _init_param():
225+
for cls in (Rule, Permission):
226+
cls.HANDLER_PARAM_TYPES.insert(-1, ORMParam)
227+
228+
Matcher.HANDLER_PARAM_TYPES = Matcher.HANDLER_PARAM_TYPES[:-1] + (
229+
ORMParam,
230+
DefaultParam,
231+
)
232+
233+
234+
_init_param()
235+
236+
237+
from .model import Model as Model
238+
from .config import Config as Config
239+
from .param import SQLDepends as SQLDepends
240+
from .config import plugin_config as plugin_config
241+
from .migrate import AlembicConfig as AlembicConfig

nonebot_plugin_orm/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
__all__ = ("Model",)
2020

2121

22-
NAMING_CONVENTION = {
22+
_NAMING_CONVENTION = {
2323
"ix": "ix_%(column_0_label)s",
2424
"uq": "uq_%(table_name)s_%(column_0_name)s",
2525
"ck": "ck_%(table_name)s_%(constraint_name)s",
@@ -29,7 +29,7 @@
2929

3030

3131
class Model(DeclarativeBase):
32-
metadata = MetaData(naming_convention=NAMING_CONVENTION)
32+
metadata = MetaData(naming_convention=_NAMING_CONVENTION)
3333

3434
if TYPE_CHECKING:
3535
__bind_key__: ClassVar[str]

nonebot_plugin_orm/param.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,10 @@
77
from typing_extensions import Annotated
88
from typing import Any, Tuple, Iterator, Sequence, AsyncIterator, cast
99

10-
from nonebot.rule import Rule
11-
from nonebot.matcher import Matcher
1210
from pydantic.fields import FieldInfo
1311
from nonebot.dependencies import Param
14-
from nonebot.permission import Permission
12+
from nonebot.params import DependParam
1513
from pydantic.typing import get_args, get_origin
16-
from nonebot.params import DependParam, DefaultParam
1714
from sqlalchemy import Row, Result, ScalarResult, select
1815
from sqlalchemy.sql.selectable import ExecutableReturnsRows
1916
from sqlalchemy.ext.asyncio import AsyncResult, AsyncScalarResult
@@ -179,12 +176,3 @@ def _check_param(
179176
@classmethod
180177
def _check_parameterless(cls, *_) -> Param | None:
181178
return
182-
183-
184-
for cls in (Rule, Permission):
185-
cls.HANDLER_PARAM_TYPES.insert(-1, ORMParam)
186-
187-
Matcher.HANDLER_PARAM_TYPES = Matcher.HANDLER_PARAM_TYPES[:-1] + (
188-
ORMParam,
189-
DefaultParam,
190-
)

0 commit comments

Comments
 (0)