|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | import sys |
4 | | -from typing import Any, Union |
5 | | -from inspect import Parameter, isclass |
| 4 | +from itertools import repeat |
| 5 | +from inspect import Parameter |
| 6 | +from dataclasses import dataclass |
6 | 7 | from typing_extensions import Annotated |
7 | | -from collections.abc import Iterator, Sequence, AsyncIterator |
| 8 | +from typing import Any, Tuple, Iterator, Sequence, AsyncIterator, cast |
8 | 9 |
|
| 10 | +from nonebot.rule import Rule |
9 | 11 | from nonebot.matcher import Matcher |
| 12 | +from pydantic.fields import FieldInfo |
10 | 13 | from nonebot.dependencies import Param |
| 14 | +from nonebot.permission import Permission |
11 | 15 | from pydantic.typing import get_args, get_origin |
12 | 16 | from nonebot.params import DependParam, DefaultParam |
13 | 17 | from sqlalchemy import Row, Result, ScalarResult, select |
| 18 | +from sqlalchemy.sql.selectable import ExecutableReturnsRows |
14 | 19 | from sqlalchemy.ext.asyncio import AsyncResult, AsyncScalarResult |
15 | 20 |
|
16 | 21 | from .model import Model |
17 | | -from .utils import Option, toclass, methodcall, compile_dependency |
| 22 | +from .utils import Option, methodcaller, compile_dependency, generic_issubclass |
18 | 23 |
|
19 | | -if sys.version_info >= (3, 10): |
20 | | - from types import NoneType, UnionType |
21 | | -else: |
22 | | - NoneType = type(None) |
23 | | - UnionType = None |
24 | | - |
25 | | - |
26 | | -def parse_model_annotation( |
27 | | - anno: Any, |
28 | | -) -> tuple[tuple[type[Model], ...], Option] | tuple[None, None]: |
29 | | - if isclass(anno) and issubclass(anno, Model): |
30 | | - return (anno,), Option(scalars=True, result=methodcall("one_or_none")) |
31 | | - |
32 | | - origin, args = get_origin(anno), get_args(anno) |
33 | | - |
34 | | - if not (origin and args): |
35 | | - return (None, None) |
36 | | - |
37 | | - if origin is Annotated: |
38 | | - return parse_model_annotation(args[0]) |
39 | | - |
40 | | - if origin in (UnionType, Union) and len(args) == 2: |
41 | | - if args[0] is NoneType: |
42 | | - return parse_model_annotation(args[1]) |
43 | | - elif args[1] is NoneType: |
44 | | - return parse_model_annotation(args[0]) |
45 | | - |
46 | | - if not isclass(origin): |
47 | | - return (None, None) |
48 | | - |
49 | | - if origin is Row: |
50 | | - origin, args = tuple, get_args(args[0]) |
51 | | - |
52 | | - if origin is tuple and all(issubclass(arg, Model) for arg in map(toclass, args)): |
53 | | - return args, Option(result=methodcall("one_or_none")) |
54 | | - |
55 | | - models, option = parse_model_annotation(args[0]) |
56 | | - if not (models and option): |
57 | | - return (None, None) |
58 | | - |
59 | | - if option.result == methodcall("all"): |
60 | | - if issubclass(Iterator, origin): |
61 | | - return models, Option(False, option.scalars, methodcall("partitions")) |
62 | | - |
63 | | - if issubclass(AsyncIterator, origin): |
64 | | - return models, Option(True, option.scalars, methodcall("partitions")) |
65 | | - |
66 | | - if option.result != methodcall("one_or_none"): |
67 | | - return (None, None) |
68 | | - |
69 | | - if ( |
70 | | - (not option.scalars and origin is Result) |
71 | | - or (option.scalars and origin is ScalarResult) |
72 | | - or issubclass(Iterator, origin) |
73 | | - ): |
74 | | - return models, Option(False, option.scalars) |
75 | | - |
76 | | - if ( |
77 | | - (not option.scalars and origin is AsyncResult) |
78 | | - or (option.scalars and origin is AsyncScalarResult) |
79 | | - or issubclass(AsyncIterator, origin) |
80 | | - ): |
81 | | - return models, Option(scalars=option.scalars) |
82 | | - |
83 | | - if issubclass(Sequence, origin): |
84 | | - return models, Option(True, option.scalars, methodcall("all")) |
85 | | - |
86 | | - return (None, None) |
| 24 | +__all__ = ( |
| 25 | + "SQLDepends", |
| 26 | + "ORMParam", |
| 27 | +) |
87 | 28 |
|
88 | 29 |
|
89 | | -class ModelParam(DependParam): |
| 30 | +PATTERNS = { |
| 31 | + AsyncIterator[Sequence[Row[Tuple[Any, ...]]]]: Option( |
| 32 | + True, |
| 33 | + False, |
| 34 | + methodcaller("partitions"), |
| 35 | + ), |
| 36 | + AsyncIterator[Sequence[Tuple[Any, ...]]]: Option( |
| 37 | + True, |
| 38 | + False, |
| 39 | + methodcaller("partitions"), |
| 40 | + ), |
| 41 | + AsyncIterator[Sequence[Any]]: Option( |
| 42 | + True, |
| 43 | + True, |
| 44 | + methodcaller("partitions"), |
| 45 | + ), |
| 46 | + Iterator[Sequence[Row[Tuple[Any, ...]]]]: Option( |
| 47 | + False, |
| 48 | + False, |
| 49 | + methodcaller("partitions"), |
| 50 | + ), |
| 51 | + Iterator[Sequence[Tuple[Any, ...]]]: Option( |
| 52 | + False, |
| 53 | + False, |
| 54 | + methodcaller("partitions"), |
| 55 | + ), |
| 56 | + Iterator[Sequence[Any]]: Option( |
| 57 | + False, |
| 58 | + True, |
| 59 | + methodcaller("partitions"), |
| 60 | + ), |
| 61 | + AsyncResult[Tuple[Any, ...]]: Option( |
| 62 | + True, |
| 63 | + False, |
| 64 | + ), |
| 65 | + AsyncScalarResult[Any]: Option( |
| 66 | + True, |
| 67 | + True, |
| 68 | + ), |
| 69 | + Result[Tuple[Any, ...]]: Option( |
| 70 | + False, |
| 71 | + False, |
| 72 | + ), |
| 73 | + ScalarResult[Any]: Option( |
| 74 | + False, |
| 75 | + True, |
| 76 | + ), |
| 77 | + AsyncIterator[Row[Tuple[Any, ...]]]: Option( |
| 78 | + True, |
| 79 | + False, |
| 80 | + ), |
| 81 | + Iterator[Row[Tuple[Any, ...]]]: Option( |
| 82 | + False, |
| 83 | + False, |
| 84 | + ), |
| 85 | + Sequence[Row[Tuple[Any, ...]]]: Option( |
| 86 | + True, |
| 87 | + False, |
| 88 | + methodcaller("all"), |
| 89 | + ), |
| 90 | + Sequence[Tuple[Any, ...]]: Option( |
| 91 | + True, |
| 92 | + False, |
| 93 | + methodcaller("all"), |
| 94 | + ), |
| 95 | + Sequence[Any]: Option( |
| 96 | + True, |
| 97 | + True, |
| 98 | + methodcaller("all"), |
| 99 | + ), |
| 100 | + Tuple[Any, ...]: Option( |
| 101 | + True, |
| 102 | + False, |
| 103 | + methodcaller("one_or_none"), |
| 104 | + ), |
| 105 | + Any: Option( |
| 106 | + True, |
| 107 | + True, |
| 108 | + methodcaller("one_or_none"), |
| 109 | + ), |
| 110 | +} |
| 111 | + |
| 112 | + |
| 113 | +@dataclass |
| 114 | +class SQLDependsInner: |
| 115 | + dependency: ExecutableReturnsRows |
| 116 | + |
| 117 | + if sys.version_info >= (3, 10): |
| 118 | + from dataclasses import KW_ONLY |
| 119 | + |
| 120 | + _: KW_ONLY |
| 121 | + |
| 122 | + use_cache: bool = True |
| 123 | + validate: bool | FieldInfo = False |
| 124 | + |
| 125 | + |
| 126 | +def SQLDepends( |
| 127 | + dependency: ExecutableReturnsRows, |
| 128 | + *, |
| 129 | + use_cache: bool = True, |
| 130 | + validate: bool | FieldInfo = False, |
| 131 | +) -> Any: |
| 132 | + return SQLDependsInner(dependency, use_cache=use_cache, validate=validate) |
| 133 | + |
| 134 | + |
| 135 | +class ORMParam(DependParam): |
90 | 136 | @classmethod |
91 | 137 | def _check_param( |
92 | 138 | cls, param: Parameter, allow_types: tuple[type[Param], ...] |
93 | 139 | ) -> Param | None: |
94 | | - models, option = parse_model_annotation(param.annotation) |
95 | | - |
96 | | - if not (models and option): |
97 | | - return |
| 140 | + type_annotation, depends_inner = param.annotation, None |
| 141 | + if get_origin(param.annotation) is Annotated: |
| 142 | + type_annotation, *extra_args = get_args(param.annotation) |
| 143 | + depends_inner = next( |
| 144 | + (x for x in reversed(extra_args) if isinstance(x, SQLDependsInner)), |
| 145 | + None, |
| 146 | + ) |
98 | 147 |
|
99 | | - stat = select(*models).where( |
100 | | - *( |
101 | | - getattr(model, name) == param.default |
102 | | - for model in models |
103 | | - for name, param in model.__signature__.parameters.items() |
| 148 | + if isinstance(param.default, SQLDependsInner): |
| 149 | + depends_inner = param.default |
| 150 | + |
| 151 | + for pattern, option in PATTERNS.items(): |
| 152 | + if models := generic_issubclass(pattern, type_annotation): |
| 153 | + break |
| 154 | + else: |
| 155 | + models, option = None, Option() |
| 156 | + |
| 157 | + if not isinstance(models, tuple): |
| 158 | + models = (models,) |
| 159 | + |
| 160 | + if depends_inner is not None: |
| 161 | + dependency = compile_dependency(depends_inner.dependency, option) |
| 162 | + elif all(map(issubclass, models, repeat(Model))): |
| 163 | + models = cast(Tuple[Model, ...], models) |
| 164 | + dependency = compile_dependency( |
| 165 | + select(*models).where( |
| 166 | + *( |
| 167 | + getattr(model, name) == param.default |
| 168 | + for model in models |
| 169 | + for name, param in model.__signature__.parameters.items() |
| 170 | + ) |
| 171 | + ), |
| 172 | + option, |
104 | 173 | ) |
105 | | - ) |
| 174 | + else: |
| 175 | + return |
106 | 176 |
|
107 | | - return super()._check_param( |
108 | | - param.replace(default=compile_dependency(stat, option)), allow_types |
109 | | - ) |
| 177 | + return super()._check_param(param.replace(default=dependency), allow_types) |
110 | 178 |
|
111 | 179 | @classmethod |
112 | | - def _check_parameterless( |
113 | | - cls, value: Any, allow_types: tuple[type[Param], ...] |
114 | | - ) -> Param | None: |
| 180 | + def _check_parameterless(cls, *_) -> Param | None: |
115 | 181 | return |
116 | 182 |
|
117 | 183 |
|
| 184 | +for cls in (Rule, Permission): |
| 185 | + cls.HANDLER_PARAM_TYPES.insert(-1, ORMParam) |
| 186 | + |
118 | 187 | Matcher.HANDLER_PARAM_TYPES = Matcher.HANDLER_PARAM_TYPES[:-1] + ( |
119 | | - ModelParam, |
| 188 | + ORMParam, |
120 | 189 | DefaultParam, |
121 | 190 | ) |
0 commit comments