Skip to content

Commit ad5f8d5

Browse files
committed
♻️ refactor(sqla): Model di by ModelParam
1 parent d6806ff commit ad5f8d5

File tree

4 files changed

+203
-44
lines changed

4 files changed

+203
-44
lines changed

nonebot_plugin_orm/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ def _init_logger():
222222

223223
from .sql import *
224224
from .model import *
225+
from .param import *
225226
from .config import *
226227
from .migrate import *
227228

nonebot_plugin_orm/model.py

Lines changed: 3 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,10 @@
44
from inspect import Parameter, Signature
55
from typing import TYPE_CHECKING, Any, ClassVar
66

7-
from nonebot.params import Depends
7+
from sqlalchemy import Table, MetaData
88
from nonebot import get_plugin_by_module_name
9-
from sqlalchemy import Table, MetaData, select
109
from pydantic.typing import get_args, get_origin
1110
from sqlalchemy.orm import Mapped, DeclarativeBase
12-
from sqlalchemy.orm.decl_api import DeclarativeAttributeIntercept
1311

1412
from .utils import DependsInner, get_annotations
1513

@@ -30,47 +28,11 @@
3028
}
3129

3230

33-
class ModelMeta(DeclarativeAttributeIntercept):
34-
if TYPE_CHECKING:
35-
__signature__: Signature
36-
37-
def __new__(
38-
mcs,
39-
name: str,
40-
bases: tuple[type, ...],
41-
namespace: dict[str, Any],
42-
**kwargs: Any,
43-
) -> ModelMeta:
44-
from . import async_scoped_session
45-
46-
cls: ModelMeta = super().__new__(mcs, name, bases, namespace, **kwargs)
47-
48-
if not (signature := getattr(cls, "__signature__", None)):
49-
return cls
50-
51-
async def dependency(
52-
*, __session: async_scoped_session, **kwargs: Any
53-
) -> ModelMeta | None:
54-
return await __session.scalar(select(cls).filter_by(**kwargs))
55-
56-
dependency.__signature__ = Signature(
57-
(
58-
Parameter(
59-
"_ModelMeta__session",
60-
Parameter.KEYWORD_ONLY,
61-
annotation=async_scoped_session,
62-
),
63-
*signature.parameters.values(),
64-
)
65-
)
66-
67-
return Annotated[cls, Depends(dependency)]
68-
69-
70-
class Model(DeclarativeBase, metaclass=ModelMeta):
31+
class Model(DeclarativeBase):
7132
metadata = MetaData(naming_convention=NAMING_CONVENTION)
7233

7334
if TYPE_CHECKING:
35+
__signature__: Signature
7436
__table__: ClassVar[Table]
7537
__bind_key__: ClassVar[str]
7638

@@ -90,8 +52,6 @@ def _setup_di(cls: type[Model]) -> None:
9052
"""Get signature for NoneBot's dependency injection,
9153
and set annotations for SQLAlchemy declarative class.
9254
"""
93-
from . import async_scoped_session
94-
9555
parameters: list[Parameter] = []
9656

9757
annotations: dict[str, Any] = {}

nonebot_plugin_orm/param.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
from __future__ import annotations
2+
3+
import sys
4+
from typing import Any, Union
5+
from inspect import Parameter, isclass
6+
from typing_extensions import Annotated
7+
from collections.abc import Iterator, Sequence, AsyncIterator
8+
9+
from nonebot.matcher import Matcher
10+
from nonebot.dependencies import Param
11+
from pydantic.typing import get_args, get_origin
12+
from nonebot.params import DependParam, DefaultParam
13+
from sqlalchemy import Row, Result, ScalarResult, select
14+
from sqlalchemy.ext.asyncio import AsyncResult, AsyncScalarResult
15+
16+
from .model import Model
17+
from .utils import Option, toclass, methodcall, compile_dependency
18+
19+
if sys.version_info >= (3, 10):
20+
from types import NoneType, UnionType
21+
else:
22+
NoneType = type(None)
23+
UnionType = None
24+
25+
26+
def parse_model_annotation(
27+
anno: Any,
28+
) -> tuple[tuple[type[Model], ...], Option] | tuple[None, None]:
29+
if isclass(anno) and issubclass(anno, Model):
30+
return (anno,), Option(scalars=True, result=methodcall("one_or_none"))
31+
32+
origin, args = get_origin(anno), get_args(anno)
33+
34+
if not (origin and args):
35+
return (None, None)
36+
37+
if origin is Annotated:
38+
return parse_model_annotation(args[0])
39+
40+
if origin in (UnionType, Union) and len(args) == 2:
41+
if args[0] is NoneType:
42+
return parse_model_annotation(args[1])
43+
elif args[1] is NoneType:
44+
return parse_model_annotation(args[0])
45+
46+
if not isclass(origin):
47+
return (None, None)
48+
49+
if origin is Row:
50+
origin, args = tuple, get_args(args[0])
51+
52+
if origin is tuple and all(issubclass(arg, Model) for arg in map(toclass, args)):
53+
return args, Option(result=methodcall("one_or_none"))
54+
55+
models, option = parse_model_annotation(args[0])
56+
if not (models and option):
57+
return (None, None)
58+
59+
if option.result == methodcall("all"):
60+
if issubclass(Iterator, origin):
61+
return models, Option(False, option.scalars, methodcall("partitions"))
62+
63+
if issubclass(AsyncIterator, origin):
64+
return models, Option(True, option.scalars, methodcall("partitions"))
65+
66+
if option.result != methodcall("one_or_none"):
67+
return (None, None)
68+
69+
if (
70+
(not option.scalars and origin is Result)
71+
or (option.scalars and origin is ScalarResult)
72+
or issubclass(Iterator, origin)
73+
):
74+
return models, Option(False, option.scalars)
75+
76+
if (
77+
(not option.scalars and origin is AsyncResult)
78+
or (option.scalars and origin is AsyncScalarResult)
79+
or issubclass(AsyncIterator, origin)
80+
):
81+
return models, Option(scalars=option.scalars)
82+
83+
if issubclass(Sequence, origin):
84+
return models, Option(True, option.scalars, methodcall("all"))
85+
86+
return (None, None)
87+
88+
89+
class ModelParam(DependParam):
90+
@classmethod
91+
def _check_param(
92+
cls, param: Parameter, allow_types: tuple[type[Param], ...]
93+
) -> Param | None:
94+
models, option = parse_model_annotation(param.annotation)
95+
96+
if not (models and option):
97+
return
98+
99+
stat = select(*models).where(
100+
*(
101+
getattr(model, name) == param.default
102+
for model in models
103+
for name, param in model.__signature__.parameters.items()
104+
)
105+
)
106+
107+
return super()._check_param(
108+
param.replace(default=compile_dependency(stat, option)), allow_types
109+
)
110+
111+
@classmethod
112+
def _check_parameterless(
113+
cls, value: Any, allow_types: tuple[type[Param], ...]
114+
) -> Param | None:
115+
return
116+
117+
118+
Matcher.HANDLER_PARAM_TYPES = Matcher.HANDLER_PARAM_TYPES[:-1] + (
119+
ModelParam,
120+
DefaultParam,
121+
)

nonebot_plugin_orm/utils.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,21 @@
55
import logging
66
from io import StringIO
77
from pathlib import Path
8-
from typing import TypeVar
98
from contextlib import suppress
109
from functools import wraps, lru_cache
10+
from typing_extensions import Annotated
11+
from dataclasses import field, dataclass
12+
from typing import Any, TypeVar, NamedTuple
1113
from collections.abc import Callable, Iterable
14+
from inspect import Parameter, Signature, isclass
1215
from importlib.metadata import Distribution, PackageNotFoundError, distribution
1316

1417
import click
1518
from nonebot.plugin import Plugin
1619
from nonebot.params import Depends
1720
from nonebot import logger, get_driver
21+
from pydantic.typing import get_args, get_origin
22+
from sqlalchemy.sql.selectable import ExecutableReturnsRows
1823

1924
if sys.version_info >= (3, 9):
2025
from importlib.resources import files
@@ -83,6 +88,78 @@ def flush(self):
8388
pass
8489

8590

91+
@dataclass
92+
class Option:
93+
stream: bool = True
94+
scalars: bool = False
95+
result: _methodcall | None = None
96+
calls: list[_methodcall] = field(default_factory=list)
97+
98+
99+
class _methodcall(NamedTuple):
100+
name: str
101+
args: tuple
102+
kwargs: dict[str, Any]
103+
104+
105+
def methodcall(name: str, args: tuple = (), kwargs: dict[str, Any] = {}) -> _methodcall:
106+
return _methodcall(name, args, kwargs)
107+
108+
109+
def compile_dependency(statement: ExecutableReturnsRows, option: Option) -> Any:
110+
from . import async_scoped_session
111+
112+
async def dependency(*, __session: async_scoped_session, **params: Any):
113+
if option.stream:
114+
result = await __session.stream(statement, params)
115+
else:
116+
result = await __session.execute(statement, params)
117+
118+
for call in option.calls:
119+
result = getattr(result, call.name)(*call.args, **call.kwargs)
120+
121+
if option.scalars:
122+
result = result.scalars()
123+
124+
if call := option.result:
125+
result = getattr(result, call.name)(*call.args, **call.kwargs)
126+
127+
if option.stream:
128+
result = await result
129+
130+
return result
131+
132+
dependency.__signature__ = Signature(
133+
[
134+
Parameter(
135+
"__session", Parameter.KEYWORD_ONLY, annotation=async_scoped_session
136+
),
137+
*(
138+
Parameter(name, Parameter.KEYWORD_ONLY, default=depends)
139+
for name, depends in statement.compile().params.items()
140+
if isinstance(depends, DependsInner)
141+
),
142+
]
143+
)
144+
145+
return Depends(dependency)
146+
147+
148+
def toclass(cls: Any) -> type:
149+
if isclass(cls):
150+
return cls
151+
152+
origin, args = get_origin(cls), get_args(cls)
153+
154+
if origin is None:
155+
raise TypeError(f"{cls!r} is not a class or a generic type")
156+
157+
if origin is Annotated:
158+
return toclass(args[0])
159+
160+
return origin
161+
162+
86163
def return_progressbar(func: Callable[_P, Iterable[_T]]) -> Callable[_P, Iterable[_T]]:
87164
log_level = get_driver().config.log_level
88165
if isinstance(log_level, str):

0 commit comments

Comments
 (0)