Skip to content

Commit 7883dae

Browse files
committed
✨ Extension.inject
1 parent 7c2fde6 commit 7883dae

File tree

6 files changed

+106
-26
lines changed

6 files changed

+106
-26
lines changed

src/nonebot_plugin_alconna/extension.py

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,30 @@
22

33
from abc import ABCMeta, abstractmethod
44
import asyncio
5+
from contextlib import AsyncExitStack
56
from dataclasses import dataclass
67
import functools
78
import importlib as imp
9+
import inspect
810
import re
9-
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, TypeVar, Union
11+
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, TypeVar, Union, final, overload
1012
from weakref import finalize
1113

1214
from arclet.alconna import Alconna, Arparma
1315
from nonebot import get_plugin_config
1416
from nonebot.adapters import Bot, Event, Message
1517
from nonebot.compat import PydanticUndefined
16-
from nonebot.typing import T_State
18+
from nonebot.dependencies import Dependent, Param
19+
from nonebot.internal.params import DependencyCache, DependParam, DependsInner
20+
from nonebot.typing import T_State, _DependentCallable
21+
from pydantic.fields import FieldInfo
1722
from tarina import LRU, lang
1823

1924
from .config import Config
2025
from .uniseg import UniMessage, get_message_id
2126

2227
OutputType = Literal["help", "shortcut", "completion", "error"]
28+
T = TypeVar("T")
2329
TM = TypeVar("TM", bound=Union[str, Message, UniMessage])
2430
TE = TypeVar("TE", bound=Event)
2531

@@ -57,6 +63,8 @@ def __init_subclass__(cls, **kwargs):
5763
"catch": cls.catch != Extension.catch and cls.before_catch != Extension.before_catch,
5864
}
5965

66+
executor: ExtensionExecutor
67+
6068
@property
6169
@abstractmethod
6270
def priority(self) -> int:
@@ -77,6 +85,40 @@ def namespace(self) -> str:
7785
def validate(self, bot: Bot, event: Event) -> bool:
7886
return event.get_type() == "message"
7987

88+
@overload
89+
async def inject(
90+
self, dependent: Dependent[T], *, use_cache: bool = True, validate: bool | FieldInfo = False
91+
) -> T: ...
92+
93+
@overload
94+
async def inject(self, dependent: tuple[str, type[T]]) -> T: ...
95+
96+
@overload
97+
async def inject(self, dependent: Any) -> Any: ...
98+
99+
@final
100+
async def inject(self, dependent: Any, use_cache: bool = True, validate: bool | FieldInfo = False) -> Any:
101+
# assert isinstance(dependent, (Dependent, DependsInner)), "仅支持 Dependent 或 DependsInner 类型的依赖注入"
102+
if isinstance(dependent, DependsInner):
103+
if not dependent.dependency:
104+
raise ValueError("DependsInner 未绑定任何依赖")
105+
use_cache = dependent.use_cache
106+
validate = dependent.validate
107+
dependent = Dependent.parse(call=dependent.dependency, allow_types=self.executor.params)
108+
param = DependParam(dependent=dependent, use_cache=use_cache, validate=validate)
109+
elif isinstance(dependent, Dependent):
110+
param = DependParam(dependent=dependent, use_cache=use_cache, validate=validate)
111+
else:
112+
for allow_type in self.executor.params:
113+
if param := allow_type._check_param(
114+
inspect.Parameter(dependent[0], inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=dependent[1]),
115+
self.executor.params,
116+
):
117+
break
118+
else:
119+
raise ValueError(f"Unknown parameter {dependent[0]} with type {dependent[1]}")
120+
return await self.executor._dependent_executor(param)
121+
80122
async def output_converter(self, output_type: OutputType, content: str) -> UniMessage:
81123
"""依据输出信息的类型,将字符串转换为消息对象以便发送。"""
82124
return UniMessage(content)
@@ -224,16 +266,39 @@ async def send_wrapper(self, bot: Bot, event: Event, send: TM) -> TM:
224266
return res
225267

226268

269+
class _DependentExecutor:
270+
def __init__(
271+
self,
272+
bot: Bot,
273+
event: Event,
274+
state: T_State,
275+
stack: AsyncExitStack | None = None,
276+
dependency_cache: dict[_DependentCallable[Any], DependencyCache] | None = None,
277+
):
278+
self.bot = bot
279+
self.event = event
280+
self.state = state
281+
self.stack = stack
282+
self.dependency_cache = dependency_cache or {}
283+
284+
async def __call__(self, param: Param):
285+
return await param._solve(
286+
stack=self.stack, dependency_cache=self.dependency_cache, bot=self.bot, event=self.event, state=self.state
287+
)
288+
289+
227290
class ExtensionExecutor(SelectedExtensions):
228291
globals: ClassVar[list[type[Extension] | Extension]] = [DefaultExtension()]
229292
_rule: AlconnaRule
293+
_dependent_executor: _DependentExecutor
230294

231295
def __init__(
232296
self,
233297
rule: AlconnaRule,
234298
extensions: list[type[Extension] | Extension] | None = None,
235299
excludes: list[str | type[Extension]] | None = None,
236300
):
301+
self.params: tuple[type[Param], ...] = ()
237302
self.extensions: list[Extension] = []
238303
for ext in self.globals:
239304
if isinstance(ext, type):
@@ -284,6 +349,8 @@ def _callback(self, *append_global_ext: type[Extension] | Extension):
284349
def select(self, bot: Bot, event: Event) -> SelectedExtensions:
285350
context = [ext for ext in self.extensions if ext.validate(bot, event)]
286351
context.sort(key=lambda ext: ext.priority)
352+
for ext in context:
353+
ext.executor = self
287354
return SelectedExtensions(context)
288355

289356
def before_catch(self, name: str, annotation: Any, default: Any) -> bool:

src/nonebot_plugin_alconna/matcher.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,8 +1012,6 @@ def on_alconna(
10121012
skip_for_unmatch,
10131013
auto_send_output,
10141014
comp_config,
1015-
extensions,
1016-
exclude_ext,
10171015
use_origin,
10181016
use_cmd_start,
10191017
use_cmd_sep,
@@ -1022,13 +1020,14 @@ def on_alconna(
10221020
rule,
10231021
after_rule,
10241022
)
1025-
executor = _rule.executor
1026-
params = (
1023+
_rule.executor = executor = ExtensionExecutor(_rule, extensions, exclude_ext)
1024+
executor.params = params = (
10271025
ExtensionParam.new(executor),
10281026
*Matcher.HANDLER_PARAM_TYPES[:-1],
10291027
AlconnaParam,
10301028
DefaultParam,
10311029
)
1030+
executor.post_init(command)
10321031
source = get_matcher_source(_depth + 1)
10331032
NewMatcher = type(
10341033
AlconnaMatcher.__name__,

src/nonebot_plugin_alconna/params.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def _check_param(cls, param: inspect.Parameter, allow_types: tuple[type[Param],
232232
return cls(param.default, name=param.name, type=param.annotation, validate=True)
233233
return None
234234

235-
async def _solve(self, matcher: Matcher, event: Event, state: T_State, **kwargs: Any) -> Any:
235+
async def _solve(self, event: Event, state: T_State, **kwargs: Any) -> Any:
236236
res = await self.executor.catch(event, state, self.extra["name"], self.extra["type"], self.default)
237237
if res is not PydanticUndefined:
238238
return res
@@ -275,7 +275,8 @@ def _check_param(cls, param: inspect.Parameter, allow_types: tuple[type[Param],
275275
return cls(..., type=Literal["context"])
276276
return cls(param.default, validate=True, name=param.name, type=param.annotation)
277277

278-
async def _solve(self, matcher: Matcher, event: Event, state: T_State, **kwargs: Any) -> Any:
278+
async def _solve(self, event: Event, state: T_State, **kwargs: Any) -> Any:
279+
matcher = kwargs.get("matcher", object())
279280
t = self.extra["type"]
280281
if ALCONNA_RESULT not in state:
281282
return self.default if self.default not in (..., Empty) else PydanticUndefined

src/nonebot_plugin_alconna/rule.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from .config import Config
2020
from .consts import ALCONNA_EXEC_RESULT, ALCONNA_EXTENSION, ALCONNA_RESULT, log
21-
from .extension import Extension, ExtensionExecutor, SelectedExtensions
21+
from .extension import ExtensionExecutor, SelectedExtensions, _DependentExecutor
2222
from .i18n import Lang
2323
from .model import CommandResult, CompConfig
2424
from .uniseg import UniMessage, UniMsg
@@ -58,6 +58,8 @@ class AlconnaRule:
5858
use_cmd_start: 是否使用 nb 全局配置里的命令前缀
5959
"""
6060

61+
executor: ExtensionExecutor
62+
6163
__slots__ = (
6264
"_comp_help",
6365
"_hide_tabs",
@@ -82,8 +84,6 @@ def __init__(
8284
skip_for_unmatch: bool = True,
8385
auto_send_output: Optional[bool] = None,
8486
comp_config: Optional[Union[CompConfig, bool]] = None,
85-
extensions: Optional[list[Union[type[Extension], Extension]]] = None,
86-
exclude_ext: Optional[list[Union[type[Extension], str]]] = None,
8787
use_origin: Optional[bool] = None,
8888
use_cmd_start: Optional[bool] = None,
8989
use_cmd_sep: Optional[bool] = None,
@@ -152,8 +152,6 @@ def _update(cmd_id: int):
152152
for alias in _aliases:
153153
command.shortcut(alias, prefix=True, compact=None)
154154
self.skip = skip_for_unmatch
155-
self.executor = ExtensionExecutor(self, extensions, exclude_ext)
156-
self.executor.post_init(command)
157155
self._path = command.path
158156
self._namespace = command.namespace
159157
self._tasks: dict[str, asyncio.Task] = {}
@@ -297,11 +295,13 @@ async def __call__(
297295
stack: Optional[AsyncExitStack] = None,
298296
dependency_cache: Optional[dict[_DependentCallable[Any], DependencyCache]] = None,
299297
) -> bool:
300-
if not await self.before_rules(bot, event, state, stack, dependency_cache):
298+
299+
if self.before_rules.checkers and not await self.before_rules(bot, event, state, stack, dependency_cache):
301300
return False
302301
if event.get_type() == "meta_event":
303302
return False
304303
selected = self.executor.select(bot, event)
304+
self.executor._dependent_executor = _DependentExecutor(bot, event, state, stack, dependency_cache)
305305
if not (msg := await selected.message_provider(event, state, bot, self.use_origin)):
306306
return False
307307
if not self.response_self and check_self_send(bot, event):
@@ -356,13 +356,13 @@ async def __call__(
356356
return False
357357
if not await selected.permission_check(bot, event, cmd):
358358
return False
359-
await selected.parse_wrapper(bot, state, event, arp)
360359
state[ALCONNA_RESULT] = CommandResult(result=arp, output=may_help_text)
361360
state[ALCONNA_EXEC_RESULT] = cmd.exec_result
362361
state[ALCONNA_EXTENSION] = selected
363-
if not await self.after_rules(bot, event, state, stack, dependency_cache):
364-
return False
365-
return True
362+
await selected.parse_wrapper(bot, state, event, arp)
363+
return not (
364+
self.after_rules.checkers and not await self.after_rules(bot, event, state, stack, dependency_cache)
365+
)
366366

367367
async def send(self, text: str, bot: Bot, event: Event, arp: Arparma) -> Any:
368368
_t = str(arp.error_info) if isinstance(arp.error_info, SpecialOptionTriggered) else "error"

src/nonebot_plugin_alconna/uniseg/message.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from json import dumps, loads
77
from pathlib import Path
88
from types import FunctionType
9-
from typing import TYPE_CHECKING, Any, Callable, Literal, NoReturn, TypeVar, Union, Protocol
9+
from typing import TYPE_CHECKING, Any, Callable, Literal, NoReturn, Protocol, TypeVar, Union
1010
from typing_extensions import Self, SupportsIndex, deprecated
1111

1212
from nonebot.exception import FinishedException

tests/test_extension.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from typing import Any
22

3-
from arclet.alconna import Alconna, Args
3+
from arclet.alconna import Alconna, Args, Arparma
44
from nonebot import get_adapter
5-
from nonebot.adapters.onebot.v11 import Adapter, Bot, Message
65
from nonebot.internal.adapter import Event
6+
from nonebot.params import Depends
7+
from nonebot.typing import T_State
78
from nonebug import App
89
import pytest
910

@@ -12,9 +13,18 @@
1213

1314
@pytest.mark.asyncio()
1415
async def test_extension(app: App):
15-
from nonebot.adapters.onebot.v11 import MessageEvent
16+
from nonebot.adapters.onebot.v11 import Adapter, Bot, Message, MessageEvent
1617

1718
from nonebot_plugin_alconna import Extension, Interface, UniMessage, on_alconna
19+
from nonebot_plugin_alconna.params import Match
20+
21+
async def perm_check(bot, event):
22+
if event.get_user_id() != "123":
23+
await bot.send(event, "权限不足!")
24+
return False
25+
return True
26+
27+
check_dep = Depends(perm_check, use_cache=True)
1828

1929
class DemoExtension(Extension):
2030
@property
@@ -26,10 +36,7 @@ def id(self) -> str:
2636
return "demo"
2737

2838
async def permission_check(self, bot, event, command):
29-
if event.get_user_id() != "123":
30-
await bot.send(event, "权限不足!")
31-
return False
32-
return True
39+
return await self.inject(check_dep)
3340

3441
def before_catch(self, name: str, annotation: Any, default: Any) -> bool:
3542
return annotation is str
@@ -47,6 +54,12 @@ async def catch(self, interface: Interface[MessageEvent]):
4754
}.get(interface.name, interface.name)
4855
return None
4956

57+
async def parse_wrapper(self, bot: Bot, state: T_State, event: Event, res: Arparma) -> None:
58+
a = await self.inject(("a", Match[float]))
59+
b = await self.inject(("b", Match[float]))
60+
assert a.result == 1.3
61+
assert b.result == 2.4
62+
5063
add = on_alconna(Alconna("add", Args["a", float]["b", float]), extensions=[DemoExtension], comp_config={})
5164

5265
@add.handle()

0 commit comments

Comments
 (0)