Skip to content

Commit c313658

Browse files
committed
SubcommandPermExtension
1 parent c7c652f commit c313658

File tree

4 files changed

+148
-8
lines changed

4 files changed

+148
-8
lines changed
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
from typing import Protocol
5+
6+
from arclet.alconna import Arparma, CompSession, SubcommandResult
7+
from nonebot.internal.adapter import Bot, Event
8+
9+
from nonebot_plugin_alconna import Extension
10+
11+
12+
class Checker(Protocol):
13+
async def __call__(self, bot: Bot, event: Event, permission: str) -> bool: ...
14+
15+
16+
class SubcommandPermExtension(Extension):
17+
"""
18+
用于简易检查调用者是否有命令权限的拓展。
19+
20+
Example:
21+
>>> from nonebot_plugin_alconna.builtins.extensions.permission import SubcommandPermExtension
22+
>>>
23+
>>> matcher = on_alconna("...", extensions=[SubcommandPermExtension(...)])
24+
"""
25+
26+
@property
27+
def priority(self) -> int:
28+
return 20
29+
30+
def __init__(self, checker: Checker, include_options: bool = False) -> None:
31+
"""
32+
Args:
33+
checker: 权限检查函数,接受 bot、event、permission 三个参数,返回是否有权限的布尔值
34+
include_options: 是否需要选项的权限检查
35+
"""
36+
self.checker = checker
37+
self.include_options = include_options
38+
39+
@property
40+
def id(self) -> str:
41+
return "builtins.extensions.permission:SubcommandPermExtension"
42+
43+
async def permission_check(self, bot: Bot, event: Event, medium: Arparma | CompSession) -> bool:
44+
if isinstance(medium, CompSession):
45+
return True
46+
base = [f"command.{medium.source.name}"]
47+
if self.include_options:
48+
base.extend(f"command.{medium.source.name}.$options.{opt}" for opt in medium.options)
49+
50+
def gen_permissions(subcommands: dict[str, SubcommandResult], prefix: str):
51+
for name, result in subcommands.items():
52+
current_perm = f"{prefix}.{name}"
53+
yield current_perm
54+
if self.include_options and result.options:
55+
for opt in result.options:
56+
yield f"{current_perm}.$options.{opt}"
57+
if result.subcommands:
58+
yield from gen_permissions(result.subcommands, current_perm)
59+
60+
base.extend(gen_permissions(medium.subcommands, f"command.{medium.source.name}"))
61+
tasks = [self.checker(bot, event, perm) for perm in base]
62+
results = await asyncio.gather(*tasks)
63+
return all(results)
64+
65+
66+
__extension__ = SubcommandPermExtension

src/nonebot_plugin_alconna/extension.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, TypeVar, Union, final, overload
1212
from weakref import finalize
1313

14-
from arclet.alconna import Alconna, Arparma
14+
from arclet.alconna import Alconna, Arparma, CompSession
1515
from nonebot import get_plugin_config
1616
from nonebot.adapters import Bot, Event, Message
1717
from nonebot.compat import PydanticUndefined
@@ -133,7 +133,7 @@ async def receive_wrapper(self, bot: Bot, event: Event, command: Alconna, receiv
133133
"""接收消息后的钩子函数。"""
134134
return receive
135135

136-
async def permission_check(self, bot: Bot, event: Event, command: Alconna) -> bool:
136+
async def permission_check(self, bot: Bot, event: Event, medium: Arparma | CompSession) -> bool:
137137
"""命令首次解析并确认头部匹配(即确认选择响应)时对发送者的权限判断"""
138138
return True
139139

@@ -220,10 +220,10 @@ async def receive_wrapper(self, bot: Bot, event: Event, command: Alconna, receiv
220220
res = await ext.receive_wrapper(bot, event, command, res)
221221
return res
222222

223-
async def permission_check(self, bot: Bot, event: Event, command: Alconna) -> bool:
223+
async def permission_check(self, bot: Bot, event: Event, medium: Arparma | CompSession) -> bool:
224224
for ext in self.context:
225225
if ext._overrides["permission_check"]:
226-
if await ext.permission_check(bot, event, command) is False:
226+
if await ext.permission_check(bot, event, medium) is False:
227227
return False
228228
continue
229229
return True

src/nonebot_plugin_alconna/rule.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ async def handle(
241241
if res:
242242
interface.exit()
243243
return res
244-
if not await selected.permission_check(bot, event, cmd):
244+
if not await selected.permission_check(bot, event, interface):
245245
return False
246246

247247
res = Arparma(
@@ -354,7 +354,7 @@ async def __call__(
354354
return False
355355
if self.skip and may_help_text:
356356
return False
357-
if not await selected.permission_check(bot, event, cmd):
357+
if not await selected.permission_check(bot, event, arp):
358358
return False
359359
state[ALCONNA_RESULT] = CommandResult(result=arp, output=may_help_text)
360360
state[ALCONNA_EXEC_RESULT] = cmd.exec_result

tests/test_extension.py

Lines changed: 76 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Any
22

3-
from arclet.alconna import Alconna, Args, Arparma
3+
from arclet.alconna import Alconna, Args, Arparma, Subcommand
44
from nonebot import get_adapter
55
from nonebot.internal.adapter import Event
66
from nonebot.params import Depends
@@ -35,7 +35,7 @@ def priority(self) -> int:
3535
def id(self) -> str:
3636
return "demo"
3737

38-
async def permission_check(self, bot, event, command):
38+
async def permission_check(self, bot, event, medium):
3939
return await self.inject(check_dep)
4040

4141
def before_catch(self, name: str, annotation: Any, default: Any) -> bool:
@@ -79,3 +79,77 @@ async def h(a: float, b: float, hello: str, world: str, test: str):
7979
event = fake_group_message_event_v11(message=Message("add 1.3 2.4"), user_id=456)
8080
ctx.receive_event(bot, event)
8181
ctx.should_call_send(event, "权限不足!")
82+
83+
84+
@pytest.mark.asyncio()
85+
async def test_extension_permission(app: App):
86+
from nonebot.adapters.onebot.v11 import Adapter, Bot, Message
87+
88+
from nonebot_plugin_alconna import on_alconna
89+
from nonebot_plugin_alconna.builtins.extensions.permission import SubcommandPermExtension
90+
91+
cmd = Alconna(
92+
"calc",
93+
Subcommand(
94+
"add",
95+
Args["a", float]["b", float],
96+
),
97+
Subcommand(
98+
"mul",
99+
Args["a", float]["b", float],
100+
),
101+
Subcommand(
102+
"div",
103+
Args["a", float]["b", float],
104+
),
105+
Subcommand(
106+
"sub",
107+
Args["a", float]["b", float],
108+
),
109+
)
110+
111+
async def checker(bot, event, permission):
112+
user_id = event.get_user_id()
113+
if user_id == "123":
114+
return {
115+
"command.calc": True,
116+
"command.calc.add": True,
117+
"command.calc.sub": True,
118+
"command.calc.mul": False,
119+
"command.calc.div": False,
120+
}[permission]
121+
if user_id == "456":
122+
return {
123+
"command.calc": True,
124+
"command.calc.add": False,
125+
"command.calc.sub": True,
126+
"command.calc.mul": False,
127+
"command.calc.div": True,
128+
}[permission]
129+
return True
130+
131+
mat = on_alconna(cmd, extensions=[SubcommandPermExtension(checker)])
132+
133+
@mat.handle()
134+
async def h(a: float, b: float):
135+
await mat.send(f"Result: {a} and {b}")
136+
137+
async with app.test_matcher(mat) as ctx: # type: ignore
138+
adapter = get_adapter(Adapter)
139+
bot = ctx.create_bot(base=Bot, adapter=adapter)
140+
141+
event = fake_group_message_event_v11(message=Message("calc add 1 2"), user_id=123)
142+
ctx.receive_event(bot, event)
143+
ctx.should_call_send(event, "Result: 1.0 and 2.0")
144+
145+
event = fake_group_message_event_v11(message=Message("calc div 1 2"), user_id=123)
146+
ctx.receive_event(bot, event)
147+
ctx.should_not_pass_rule()
148+
149+
event = fake_group_message_event_v11(message=Message("calc div 5 3"), user_id=456)
150+
ctx.receive_event(bot, event)
151+
ctx.should_call_send(event, "Result: 5.0 and 3.0")
152+
153+
event = fake_group_message_event_v11(message=Message("calc add 1 2"), user_id=456)
154+
ctx.receive_event(bot, event)
155+
ctx.should_not_pass_rule()

0 commit comments

Comments
 (0)