Skip to content

Commit d0fd030

Browse files
committed
♻️ refactor(sqla): ORMParam
1 parent 95624a6 commit d0fd030

File tree

4 files changed

+233
-387
lines changed

4 files changed

+233
-387
lines changed

nonebot_plugin_orm/__init__.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,21 +28,15 @@
2828
__all__ = (
2929
# __init__
3030
"init_orm",
31+
"get_session",
32+
"AsyncSession",
3133
"get_scoped_session",
32-
# sql
33-
"one",
34-
"all_",
35-
"first",
36-
"select",
37-
"scalars",
38-
"scalar_all",
39-
"scalar_one",
40-
"one_or_none",
41-
"scalar_first",
42-
"one_or_create",
43-
"scalar_one_or_none",
34+
"async_scoped_session",
4435
# model
4536
"Model",
37+
# param
38+
"SQLDepends",
39+
"ORMParam",
4640
# config
4741
"Config",
4842
"config",
@@ -222,7 +216,6 @@ def _init_logger():
222216
l.setLevel(level)
223217

224218

225-
from .sql import *
226219
from .model import *
227220
from .param import *
228221
from .config import *

nonebot_plugin_orm/param.py

Lines changed: 159 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -1,121 +1,190 @@
11
from __future__ import annotations
22

33
import sys
4-
from typing import Any, Union
5-
from inspect import Parameter, isclass
4+
from itertools import repeat
5+
from inspect import Parameter
6+
from dataclasses import dataclass
67
from typing_extensions import Annotated
7-
from collections.abc import Iterator, Sequence, AsyncIterator
8+
from typing import Any, Tuple, Iterator, Sequence, AsyncIterator, cast
89

10+
from nonebot.rule import Rule
911
from nonebot.matcher import Matcher
12+
from pydantic.fields import FieldInfo
1013
from nonebot.dependencies import Param
14+
from nonebot.permission import Permission
1115
from pydantic.typing import get_args, get_origin
1216
from nonebot.params import DependParam, DefaultParam
1317
from sqlalchemy import Row, Result, ScalarResult, select
18+
from sqlalchemy.sql.selectable import ExecutableReturnsRows
1419
from sqlalchemy.ext.asyncio import AsyncResult, AsyncScalarResult
1520

1621
from .model import Model
17-
from .utils import Option, toclass, methodcall, compile_dependency
22+
from .utils import Option, methodcaller, compile_dependency, generic_issubclass
1823

19-
if sys.version_info >= (3, 10):
20-
from types import NoneType, UnionType
21-
else:
22-
NoneType = type(None)
23-
UnionType = None
24-
25-
26-
def parse_model_annotation(
27-
anno: Any,
28-
) -> tuple[tuple[type[Model], ...], Option] | tuple[None, None]:
29-
if isclass(anno) and issubclass(anno, Model):
30-
return (anno,), Option(scalars=True, result=methodcall("one_or_none"))
31-
32-
origin, args = get_origin(anno), get_args(anno)
33-
34-
if not (origin and args):
35-
return (None, None)
36-
37-
if origin is Annotated:
38-
return parse_model_annotation(args[0])
39-
40-
if origin in (UnionType, Union) and len(args) == 2:
41-
if args[0] is NoneType:
42-
return parse_model_annotation(args[1])
43-
elif args[1] is NoneType:
44-
return parse_model_annotation(args[0])
45-
46-
if not isclass(origin):
47-
return (None, None)
48-
49-
if origin is Row:
50-
origin, args = tuple, get_args(args[0])
51-
52-
if origin is tuple and all(issubclass(arg, Model) for arg in map(toclass, args)):
53-
return args, Option(result=methodcall("one_or_none"))
54-
55-
models, option = parse_model_annotation(args[0])
56-
if not (models and option):
57-
return (None, None)
58-
59-
if option.result == methodcall("all"):
60-
if issubclass(Iterator, origin):
61-
return models, Option(False, option.scalars, methodcall("partitions"))
62-
63-
if issubclass(AsyncIterator, origin):
64-
return models, Option(True, option.scalars, methodcall("partitions"))
65-
66-
if option.result != methodcall("one_or_none"):
67-
return (None, None)
68-
69-
if (
70-
(not option.scalars and origin is Result)
71-
or (option.scalars and origin is ScalarResult)
72-
or issubclass(Iterator, origin)
73-
):
74-
return models, Option(False, option.scalars)
75-
76-
if (
77-
(not option.scalars and origin is AsyncResult)
78-
or (option.scalars and origin is AsyncScalarResult)
79-
or issubclass(AsyncIterator, origin)
80-
):
81-
return models, Option(scalars=option.scalars)
82-
83-
if issubclass(Sequence, origin):
84-
return models, Option(True, option.scalars, methodcall("all"))
85-
86-
return (None, None)
24+
__all__ = (
25+
"SQLDepends",
26+
"ORMParam",
27+
)
8728

8829

89-
class ModelParam(DependParam):
30+
PATTERNS = {
31+
AsyncIterator[Sequence[Row[Tuple[Any, ...]]]]: Option(
32+
True,
33+
False,
34+
methodcaller("partitions"),
35+
),
36+
AsyncIterator[Sequence[Tuple[Any, ...]]]: Option(
37+
True,
38+
False,
39+
methodcaller("partitions"),
40+
),
41+
AsyncIterator[Sequence[Any]]: Option(
42+
True,
43+
True,
44+
methodcaller("partitions"),
45+
),
46+
Iterator[Sequence[Row[Tuple[Any, ...]]]]: Option(
47+
False,
48+
False,
49+
methodcaller("partitions"),
50+
),
51+
Iterator[Sequence[Tuple[Any, ...]]]: Option(
52+
False,
53+
False,
54+
methodcaller("partitions"),
55+
),
56+
Iterator[Sequence[Any]]: Option(
57+
False,
58+
True,
59+
methodcaller("partitions"),
60+
),
61+
AsyncResult[Tuple[Any, ...]]: Option(
62+
True,
63+
False,
64+
),
65+
AsyncScalarResult[Any]: Option(
66+
True,
67+
True,
68+
),
69+
Result[Tuple[Any, ...]]: Option(
70+
False,
71+
False,
72+
),
73+
ScalarResult[Any]: Option(
74+
False,
75+
True,
76+
),
77+
AsyncIterator[Row[Tuple[Any, ...]]]: Option(
78+
True,
79+
False,
80+
),
81+
Iterator[Row[Tuple[Any, ...]]]: Option(
82+
False,
83+
False,
84+
),
85+
Sequence[Row[Tuple[Any, ...]]]: Option(
86+
True,
87+
False,
88+
methodcaller("all"),
89+
),
90+
Sequence[Tuple[Any, ...]]: Option(
91+
True,
92+
False,
93+
methodcaller("all"),
94+
),
95+
Sequence[Any]: Option(
96+
True,
97+
True,
98+
methodcaller("all"),
99+
),
100+
Tuple[Any, ...]: Option(
101+
True,
102+
False,
103+
methodcaller("one_or_none"),
104+
),
105+
Any: Option(
106+
True,
107+
True,
108+
methodcaller("one_or_none"),
109+
),
110+
}
111+
112+
113+
@dataclass
114+
class SQLDependsInner:
115+
dependency: ExecutableReturnsRows
116+
117+
if sys.version_info >= (3, 10):
118+
from dataclasses import KW_ONLY
119+
120+
_: KW_ONLY
121+
122+
use_cache: bool = True
123+
validate: bool | FieldInfo = False
124+
125+
126+
def SQLDepends(
127+
dependency: ExecutableReturnsRows,
128+
*,
129+
use_cache: bool = True,
130+
validate: bool | FieldInfo = False,
131+
) -> Any:
132+
return SQLDependsInner(dependency, use_cache=use_cache, validate=validate)
133+
134+
135+
class ORMParam(DependParam):
90136
@classmethod
91137
def _check_param(
92138
cls, param: Parameter, allow_types: tuple[type[Param], ...]
93139
) -> Param | None:
94-
models, option = parse_model_annotation(param.annotation)
95-
96-
if not (models and option):
97-
return
140+
type_annotation, depends_inner = param.annotation, None
141+
if get_origin(param.annotation) is Annotated:
142+
type_annotation, *extra_args = get_args(param.annotation)
143+
depends_inner = next(
144+
(x for x in reversed(extra_args) if isinstance(x, SQLDependsInner)),
145+
None,
146+
)
98147

99-
stat = select(*models).where(
100-
*(
101-
getattr(model, name) == param.default
102-
for model in models
103-
for name, param in model.__signature__.parameters.items()
148+
if isinstance(param.default, SQLDependsInner):
149+
depends_inner = param.default
150+
151+
for pattern, option in PATTERNS.items():
152+
if models := generic_issubclass(pattern, type_annotation):
153+
break
154+
else:
155+
models, option = None, Option()
156+
157+
if not isinstance(models, tuple):
158+
models = (models,)
159+
160+
if depends_inner is not None:
161+
dependency = compile_dependency(depends_inner.dependency, option)
162+
elif all(map(issubclass, models, repeat(Model))):
163+
models = cast(Tuple[Model, ...], models)
164+
dependency = compile_dependency(
165+
select(*models).where(
166+
*(
167+
getattr(model, name) == param.default
168+
for model in models
169+
for name, param in model.__signature__.parameters.items()
170+
)
171+
),
172+
option,
104173
)
105-
)
174+
else:
175+
return
106176

107-
return super()._check_param(
108-
param.replace(default=compile_dependency(stat, option)), allow_types
109-
)
177+
return super()._check_param(param.replace(default=dependency), allow_types)
110178

111179
@classmethod
112-
def _check_parameterless(
113-
cls, value: Any, allow_types: tuple[type[Param], ...]
114-
) -> Param | None:
180+
def _check_parameterless(cls, *_) -> Param | None:
115181
return
116182

117183

184+
for cls in (Rule, Permission):
185+
cls.HANDLER_PARAM_TYPES.insert(-1, ORMParam)
186+
118187
Matcher.HANDLER_PARAM_TYPES = Matcher.HANDLER_PARAM_TYPES[:-1] + (
119-
ModelParam,
188+
ORMParam,
120189
DefaultParam,
121190
)

0 commit comments

Comments
 (0)