Skip to content

Commit 9f9fd5f

Browse files
committed
🐛 extension select
resolve #94
1 parent 3a046c3 commit 9f9fd5f

File tree

5 files changed

+99
-129
lines changed

5 files changed

+99
-129
lines changed

src/nonebot_plugin_alconna/extension.py

Lines changed: 72 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import importlib as imp
77
from weakref import finalize
88
from dataclasses import dataclass
9-
from typing_extensions import Self
109
from abc import ABCMeta, abstractmethod
1110
from typing import TYPE_CHECKING, Any, Union, Generic, Literal, TypeVar, ClassVar
1211

@@ -129,82 +128,9 @@ def id(self) -> str:
129128
unimsg_origin_cache: LRU[str, UniMessage] = LRU(16)
130129

131130

132-
class ExtensionExecutor:
133-
globals: ClassVar[list[type[Extension] | Extension]] = [DefaultExtension()]
134-
_rule: AlconnaRule
135-
136-
def __init__(
137-
self,
138-
rule: AlconnaRule,
139-
extensions: list[type[Extension] | Extension] | None = None,
140-
excludes: list[str | type[Extension]] | None = None,
141-
):
142-
self.extensions: list[Extension] = []
143-
for ext in self.globals:
144-
if isinstance(ext, type):
145-
self.extensions.append(ext())
146-
else:
147-
self.extensions.append(ext)
148-
for ext in extensions or []:
149-
if isinstance(ext, type):
150-
self.extensions.append(ext())
151-
else:
152-
self.extensions.append(ext)
153-
for exl in excludes or []:
154-
if isinstance(exl, str) and exl.startswith("!"):
155-
raise ValueError(lang.require("nbp-alc", "error.extension.forbid_exclude"))
156-
self._excludes = set(excludes or [])
157-
self.extensions = [
158-
ext
159-
for ext in self.extensions
160-
if ext.id not in self._excludes
161-
and ext.__class__ not in self._excludes
162-
and (not (ns := ext.namespace) or ns == rule._namespace)
163-
]
164-
self.context: list[Extension] = []
165-
self._rule = rule
166-
167-
_callbacks.add(self._callback)
168-
169-
finalize(self, _callbacks.remove, self._callback)
170-
171-
def _callback(self, *append_global_ext: type[Extension] | Extension):
172-
for _ext in append_global_ext:
173-
if isinstance(_ext, type):
174-
_ext = _ext()
175-
if _ext.id in self._excludes or _ext.__class__ in self._excludes:
176-
continue
177-
if (ns := _ext.namespace) and ns != self._rule._namespace:
178-
continue
179-
self.extensions.append(_ext)
180-
_ext.post_init(self._rule.command()) # type: ignore
181-
182-
def __enter__(self) -> Self:
183-
return self
184-
185-
def __exit__(self, exc_type, exc_value, traceback):
186-
self.clear()
187-
188-
def select(self, bot: Bot, event: Event) -> Self:
189-
self.context = [ext for ext in self.extensions if ext.validate(bot, event)]
190-
self.context.sort(key=lambda ext: ext.priority)
191-
return self
192-
193-
def clear(self) -> None:
194-
self.context.clear()
195-
196-
async def output_converter(self, output_type: OutputType, content: str) -> UniMessage:
197-
exc = None
198-
for ext in self.context:
199-
if not ext._overrides["output_converter"]:
200-
continue
201-
try:
202-
return await ext.output_converter(output_type, content)
203-
except Exception as e:
204-
exc = e
205-
if not exc:
206-
return UniMessage()
207-
raise exc # type: ignore
131+
@dataclass
132+
class SelectedExtensions:
133+
context: list[Extension]
208134

209135
async def message_provider(
210136
self, event: Event, state: T_State, bot: Bot, use_origin: bool = False
@@ -269,13 +195,82 @@ async def parse_wrapper(self, bot: Bot, state: T_State, event: Event, res: Arpar
269195
*(ext.parse_wrapper(bot, state, event, res) for ext in self.context if ext._overrides["parse_wrapper"])
270196
)
271197

198+
async def output_converter(self, output_type: OutputType, content: str) -> UniMessage:
199+
exc = None
200+
for ext in self.context:
201+
if not ext._overrides["output_converter"]:
202+
continue
203+
try:
204+
return await ext.output_converter(output_type, content)
205+
except Exception as e:
206+
exc = e
207+
if not exc:
208+
return UniMessage()
209+
raise exc # type: ignore
210+
272211
async def send_wrapper(self, bot: Bot, event: Event, send: TM) -> TM:
273212
res = send
274213
for ext in self.context:
275214
if ext._overrides["send_wrapper"]:
276215
res = await ext.send_wrapper(bot, event, res)
277216
return res
278217

218+
219+
class ExtensionExecutor(SelectedExtensions):
220+
globals: ClassVar[list[type[Extension] | Extension]] = [DefaultExtension()]
221+
_rule: AlconnaRule
222+
223+
def __init__(
224+
self,
225+
rule: AlconnaRule,
226+
extensions: list[type[Extension] | Extension] | None = None,
227+
excludes: list[str | type[Extension]] | None = None,
228+
):
229+
self.extensions: list[Extension] = []
230+
for ext in self.globals:
231+
if isinstance(ext, type):
232+
self.extensions.append(ext())
233+
else:
234+
self.extensions.append(ext)
235+
for ext in extensions or []:
236+
if isinstance(ext, type):
237+
self.extensions.append(ext())
238+
else:
239+
self.extensions.append(ext)
240+
for exl in excludes or []:
241+
if isinstance(exl, str) and exl.startswith("!"):
242+
raise ValueError(lang.require("nbp-alc", "error.extension.forbid_exclude"))
243+
self._excludes = set(excludes or [])
244+
self.extensions = [
245+
ext
246+
for ext in self.extensions
247+
if ext.id not in self._excludes
248+
and ext.__class__ not in self._excludes
249+
and (not (ns := ext.namespace) or ns == rule._namespace)
250+
]
251+
self.context = self.extensions
252+
self._rule = rule
253+
254+
_callbacks.add(self._callback)
255+
256+
finalize(self, _callbacks.remove, self._callback)
257+
258+
def _callback(self, *append_global_ext: type[Extension] | Extension):
259+
for _ext in append_global_ext:
260+
if isinstance(_ext, type):
261+
_ext = _ext()
262+
if _ext.id in self._excludes or _ext.__class__ in self._excludes:
263+
continue
264+
if (ns := _ext.namespace) and ns != self._rule._namespace:
265+
continue
266+
self.extensions.append(_ext)
267+
_ext.post_init(self._rule.command()) # type: ignore
268+
269+
def select(self, bot: Bot, event: Event) -> SelectedExtensions:
270+
context = [ext for ext in self.extensions if ext.validate(bot, event)]
271+
context.sort(key=lambda ext: ext.priority)
272+
return SelectedExtensions(context)
273+
279274
def before_catch(self, name: str, annotation: Any, default: Any) -> bool:
280275
for ext in self.extensions:
281276
if ext._overrides["catch"]:

src/nonebot_plugin_alconna/matcher.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from tarina.lang.model import LangItem
1818
from nonebot.permission import Permission
1919
from nonebot.dependencies import Dependent
20-
from nonebot.message import run_postprocessor
2120
from arclet.alconna.tools import AlconnaFormat
2221
from nonebot.consts import ARG_KEY, RECEIVE_KEY
2322
from nonebot.internal.params import DefaultParam
@@ -1074,9 +1073,3 @@ def referent(cmd: str | Alconna | None) -> type[AlconnaMatcher] | None:
10741073
except KeyError:
10751074
return None
10761075
return None
1077-
1078-
1079-
@run_postprocessor
1080-
@annotation(matcher=AlconnaMatcher)
1081-
def _exit_executor(matcher: AlconnaMatcher):
1082-
matcher.executor.clear()

src/nonebot_plugin_alconna/params.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from .typings import CHECK, MIDDLEWARE
1818
from .model import T, Match, Query, CommandResult
19-
from .extension import Extension, ExtensionExecutor
19+
from .extension import Extension, ExtensionExecutor, SelectedExtensions
2020
from .consts import ALCONNA_RESULT, ALCONNA_ARG_KEY, ALCONNA_ARG_PATH, ALCONNA_EXTENSION, ALCONNA_EXEC_RESULT
2121

2222
T_Duplication = TypeVar("T_Duplication", bound=Duplication)
@@ -204,8 +204,8 @@ async def _arparma_check(bot: Bot, state: T_State, event: Event, matcher: Matche
204204

205205
def AlconnaExtension(target: type[T_Extension]) -> T_Extension:
206206
def _alconna_extension(state: T_State):
207-
exts = state[ALCONNA_EXTENSION]
208-
return next((i for i in exts if isinstance(i, target)), None) # type: ignore
207+
selected: SelectedExtensions = state[ALCONNA_EXTENSION]
208+
return next((i for i in selected.context if isinstance(i, target)), None)
209209

210210
return Depends(_alconna_extension, use_cache=False)
211211

@@ -261,6 +261,8 @@ def _check_param(cls, param: inspect.Parameter, allow_types: tuple[type[Param],
261261
return cls(..., type=Alconna)
262262
if annotation is Duplication:
263263
return cls(..., type=Duplication)
264+
if annotation is SelectedExtensions:
265+
return cls(..., type=SelectedExtensions)
264266
if inspect.isclass(annotation) and issubclass(annotation, Duplication):
265267
return cls(..., anno=param.annotation, type=Duplication)
266268
if inspect.isclass(annotation) and issubclass(annotation, Extension):
@@ -288,9 +290,11 @@ async def _solve(self, matcher: Matcher, event: Event, state: T_State, **kwargs:
288290
if anno := self.extra.get("anno"):
289291
return anno(res.result)
290292
return generate_duplication(res.source)(res.result)
293+
if t is SelectedExtensions:
294+
return state[ALCONNA_EXTENSION]
291295
if t is Extension:
292296
anno = self.extra["anno"]
293-
return next((i for i in state[ALCONNA_EXTENSION] if isinstance(i, anno)), None) # type: ignore
297+
return next((i for i in state[ALCONNA_EXTENSION].context if isinstance(i, anno)), None) # type: ignore
294298
if t is Match:
295299
target = res.result.all_matched_args.get(self.extra["name"], Empty)
296300
return Match(target, target != Empty)
@@ -320,18 +324,6 @@ async def _solve(self, matcher: Matcher, event: Event, state: T_State, **kwargs:
320324
return result
321325
return self.default if self.default not in (..., Empty) else PydanticUndefined
322326

323-
# async def _check(self, state: T_State, **kwargs: Any) -> Any:
324-
# if self.extra["type"] == Any:
325-
# if (
326-
# self.extra["name"] in _alconna_result(state).result.all_matched_args
327-
# or ALCONNA_ARG_KEY.format(key=self.extra["name"]) in state
328-
# or ((path := state.get(ALCONNA_ARG_PATH)) and path.endswith(f".{self.extra['name']}"))
329-
# ):
330-
# return True
331-
# if self.default not in (..., Empty):
332-
# return True
333-
# return False
334-
335327

336328
class _Dispatch:
337329
def __init__(

src/nonebot_plugin_alconna/rule.py

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from .uniseg import UniMsg, UniMessage
2020
from .model import CompConfig, CommandResult
2121
from .uniseg.constraint import UNISEG_MESSAGE
22-
from .extension import Extension, ExtensionExecutor
22+
from .extension import Extension, ExtensionExecutor, SelectedExtensions
2323
from .consts import ALCONNA_RESULT, ALCONNA_EXTENSION, ALCONNA_EXEC_RESULT, log
2424

2525
try:
@@ -212,9 +212,9 @@ def __hash__(self) -> int:
212212
return hash(self.command.__hash__())
213213

214214
async def handle(
215-
self, cmd: Alconna, bot: Bot, event: Event, state: T_State, msg: UniMessage
215+
self, selected: SelectedExtensions, cmd: Alconna, bot: Bot, event: Event, state: T_State, msg: UniMessage
216216
) -> Union[Arparma, Literal[False]]:
217-
ctx = await self.executor.context_provider(event, bot, state)
217+
ctx = await selected.context_provider(event, bot, state)
218218
try:
219219
session_id = event.get_session_id()
220220
except ValueError:
@@ -228,7 +228,7 @@ async def handle(
228228
if res:
229229
interface.exit()
230230
return res
231-
if not await self.executor.permission_check(bot, event, cmd):
231+
if not await selected.permission_check(bot, event, cmd):
232232
return False
233233

234234
res = Arparma(
@@ -271,12 +271,10 @@ def _checker(_event: Event):
271271
return res
272272

273273
async def __call__(self, event: Event, state: T_State, bot: Bot) -> bool:
274-
self.executor.select(bot, event)
275-
if not (msg := await self.executor.message_provider(event, state, bot, self.use_origin)):
276-
self.executor.clear()
274+
selected = self.executor.select(bot, event)
275+
if not (msg := await selected.message_provider(event, state, bot, self.use_origin)):
277276
return False
278277
if not self.response_self and check_self_send(bot, event):
279-
self.executor.clear()
280278
return False
281279
try:
282280
session_id = event.get_session_id()
@@ -286,38 +284,33 @@ async def __call__(self, event: Event, state: T_State, bot: Bot) -> bool:
286284
await self._tasks[session_id]
287285
cmd = self.command()
288286
if not cmd:
289-
self.executor.clear()
290287
return False
291288
if command_manager.is_disable(cmd):
292-
self.executor.clear()
293289
return False
294-
msg = await self.executor.receive_wrapper(bot, event, cmd, msg)
290+
msg = await selected.receive_wrapper(bot, event, cmd, msg)
295291
Arparma._additional.update(bot=lambda: bot, event=lambda: event, state=lambda: state)
296292
state[UNISEG_MESSAGE] = msg
297293

298294
with output_manager.capture(cmd.name) as cap:
299295
output_manager.set_action(lambda x: x, cmd.name)
300-
task = asyncio.create_task(self.handle(cmd, bot, event, state, msg))
296+
task = asyncio.create_task(self.handle(selected, cmd, bot, event, state, msg))
301297
if session_id:
302298
self._tasks[session_id] = task
303299
task.add_done_callback(lambda _: self._tasks.pop(session_id, None))
304300
try:
305301
arp = await task
306302
if arp is False:
307-
self.executor.clear()
308303
return False
309304
except Exception as e:
310305
arp = Arparma(cmd._hash, msg, False, error_info=e)
311306
may_help_text: Optional[str] = cap.get("output", None)
312307
if not arp.head_matched:
313-
self.executor.clear()
314308
return False
315309
if not arp.matched and not may_help_text and self.skip:
316310
log(
317311
"TRACE",
318312
escape_tag(Lang.nbp_alc.log.parse(msg=msg, cmd=self._path, arp=arp)),
319313
)
320-
self.executor.clear()
321314
return False
322315
if arp.head_matched:
323316
log(
@@ -328,18 +321,15 @@ async def __call__(self, event: Event, state: T_State, bot: Bot) -> bool:
328321
may_help_text = str(arp.error_info)
329322
if self.auto_send and may_help_text:
330323
await self.send(may_help_text, bot, event, arp)
331-
self.executor.clear()
332324
return False
333325
if self.skip and may_help_text:
334-
self.executor.clear()
335326
return False
336-
if not await self.executor.permission_check(bot, event, cmd):
337-
self.executor.clear()
327+
if not await selected.permission_check(bot, event, cmd):
338328
return False
339-
await self.executor.parse_wrapper(bot, state, event, arp)
329+
await selected.parse_wrapper(bot, state, event, arp)
340330
state[ALCONNA_RESULT] = CommandResult(result=arp, output=may_help_text)
341331
state[ALCONNA_EXEC_RESULT] = cmd.exec_result
342-
state[ALCONNA_EXTENSION] = self.executor.context
332+
state[ALCONNA_EXTENSION] = selected
343333
return True
344334

345335
async def send(self, text: str, bot: Bot, event: Event, arp: Arparma) -> Any:

src/nonebot_plugin_alconna/uniseg/adapters/kook/exporter.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -203,18 +203,18 @@ async def reaction(
203203
if isinstance(context, Target):
204204
if context.private:
205205
if delete:
206-
return await bot.directMessage_deleteReaction(msg_id=msg_id, emoji=emoji.name or emoji.id)
207-
return await bot.directMessage_addReaction(msg_id=msg_id, emoji=emoji.name or emoji.id)
206+
return await bot.directMessage_deleteReaction(msg_id=msg_id, emoji=emoji.id)
207+
return await bot.directMessage_addReaction(msg_id=msg_id, emoji=emoji.id)
208208
if delete:
209-
return await bot.message_deleteReaction(msg_id=msg_id, emoji=emoji.name or emoji.id)
210-
return await bot.message_addReaction(msg_id=msg_id, emoji=emoji.name or emoji.id)
209+
return await bot.message_deleteReaction(msg_id=msg_id, emoji=emoji.id)
210+
return await bot.message_addReaction(msg_id=msg_id, emoji=emoji.id)
211211
if isinstance(context, PrivateMessageEvent):
212212
if delete:
213-
return await bot.directMessage_deleteReaction(msg_id=msg_id, emoji=emoji.name or emoji.id)
214-
return await bot.directMessage_addReaction(msg_id=msg_id, emoji=emoji.name or emoji.id)
213+
return await bot.directMessage_deleteReaction(msg_id=msg_id, emoji=emoji.id)
214+
return await bot.directMessage_addReaction(msg_id=msg_id, emoji=emoji.id)
215215
if delete:
216-
return await bot.message_deleteReaction(msg_id=msg_id, emoji=emoji.name or emoji.id)
217-
return await bot.message_addReaction(msg_id=msg_id, emoji=emoji.name or emoji.id)
216+
return await bot.message_deleteReaction(msg_id=msg_id, emoji=emoji.id)
217+
return await bot.message_addReaction(msg_id=msg_id, emoji=emoji.id)
218218

219219
def get_reply(self, mid: Any):
220220
_mid: MessageCreateReturn = cast(MessageCreateReturn, mid)

0 commit comments

Comments
 (0)