Skip to content

Commit 32bc2c3

Browse files
authored
✨ Feature: 存储 matcher 发送 prompt 的结果 (#3155)
1 parent ab8dea5 commit 32bc2c3

File tree

8 files changed

+271
-22
lines changed

8 files changed

+271
-22
lines changed

nonebot/consts.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@
2222
"""当前 `reject` 目标存储 key"""
2323
REJECT_CACHE_TARGET: Literal["_next_target"] = "_next_target"
2424
"""下一个 `reject` 目标存储 key"""
25+
PAUSE_PROMPT_RESULT_KEY: Literal["_pause_result"] = "_pause_result"
26+
"""`pause` prompt 发送结果存储 key"""
27+
REJECT_PROMPT_RESULT_KEY: Literal["_reject_{key}_result"] = "_reject_{key}_result"
28+
"""`reject` prompt 发送结果存储 key"""
2529

2630
# used by Rule
2731
PREFIX_KEY: Literal["_prefix"] = "_prefix"

nonebot/internal/matcher/matcher.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@
2727
from nonebot.consts import (
2828
ARG_KEY,
2929
LAST_RECEIVE_KEY,
30+
PAUSE_PROMPT_RESULT_KEY,
3031
RECEIVE_KEY,
3132
REJECT_CACHE_TARGET,
33+
REJECT_PROMPT_RESULT_KEY,
3234
REJECT_TARGET,
3335
)
3436
from nonebot.dependencies import Dependent, Param
@@ -560,8 +562,8 @@ async def send(
560562
"""
561563
bot = current_bot.get()
562564
event = current_event.get()
563-
state = current_matcher.get().state
564565
if isinstance(message, MessageTemplate):
566+
state = current_matcher.get().state
565567
_message = message.format(**state)
566568
else:
567569
_message = message
@@ -597,8 +599,15 @@ async def pause(
597599
kwargs: {ref}`nonebot.adapters.Bot.send` 的参数,
598600
请参考对应 adapter 的 bot 对象 api
599601
"""
602+
try:
603+
matcher = current_matcher.get()
604+
except Exception:
605+
matcher = None
606+
600607
if prompt is not None:
601-
await cls.send(prompt, **kwargs)
608+
result = await cls.send(prompt, **kwargs)
609+
if matcher is not None:
610+
matcher.state[PAUSE_PROMPT_RESULT_KEY] = result
602611
raise PausedException
603612

604613
@classmethod
@@ -615,8 +624,19 @@ async def reject(
615624
kwargs: {ref}`nonebot.adapters.Bot.send` 的参数,
616625
请参考对应 adapter 的 bot 对象 api
617626
"""
627+
try:
628+
matcher = current_matcher.get()
629+
key = matcher.get_target()
630+
except Exception:
631+
matcher = None
632+
key = None
633+
634+
key = REJECT_PROMPT_RESULT_KEY.format(key=key) if key is not None else None
635+
618636
if prompt is not None:
619-
await cls.send(prompt, **kwargs)
637+
result = await cls.send(prompt, **kwargs)
638+
if key is not None and matcher:
639+
matcher.state[key] = result
620640
raise RejectedException
621641

622642
@classmethod
@@ -636,9 +656,12 @@ async def reject_arg(
636656
请参考对应 adapter 的 bot 对象 api
637657
"""
638658
matcher = current_matcher.get()
639-
matcher.set_target(ARG_KEY.format(key=key))
659+
arg_key = ARG_KEY.format(key=key)
660+
matcher.set_target(arg_key)
661+
640662
if prompt is not None:
641-
await cls.send(prompt, **kwargs)
663+
result = await cls.send(prompt, **kwargs)
664+
matcher.state[REJECT_PROMPT_RESULT_KEY.format(key=arg_key)] = result
642665
raise RejectedException
643666

644667
@classmethod
@@ -658,9 +681,12 @@ async def reject_receive(
658681
请参考对应 adapter 的 bot 对象 api
659682
"""
660683
matcher = current_matcher.get()
661-
matcher.set_target(RECEIVE_KEY.format(id=id))
684+
receive_key = RECEIVE_KEY.format(id=id)
685+
matcher.set_target(receive_key)
686+
662687
if prompt is not None:
663-
await cls.send(prompt, **kwargs)
688+
result = await cls.send(prompt, **kwargs)
689+
matcher.state[REJECT_PROMPT_RESULT_KEY.format(key=receive_key)] = result
664690
raise RejectedException
665691

666692
@classmethod

nonebot/internal/params.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from pydantic.fields import FieldInfo as PydanticFieldInfo
1919

2020
from nonebot.compat import FieldInfo, ModelField, PydanticUndefined, extract_field_info
21+
from nonebot.consts import ARG_KEY, REJECT_PROMPT_RESULT_KEY
2122
from nonebot.dependencies import Dependent, Param
2223
from nonebot.dependencies.utils import check_field_type
2324
from nonebot.exception import SkippedException
@@ -39,7 +40,7 @@
3940
)
4041

4142
if TYPE_CHECKING:
42-
from nonebot.adapters import Bot, Event
43+
from nonebot.adapters import Bot, Event, Message
4344
from nonebot.matcher import Matcher
4445

4546

@@ -522,10 +523,10 @@ async def _check( # pyright: ignore[reportIncompatibleMethodOverride]
522523

523524
class ArgInner:
524525
def __init__(
525-
self, key: Optional[str], type: Literal["message", "str", "plaintext"]
526+
self, key: Optional[str], type: Literal["message", "str", "plaintext", "prompt"]
526527
) -> None:
527528
self.key: Optional[str] = key
528-
self.type: Literal["message", "str", "plaintext"] = type
529+
self.type: Literal["message", "str", "plaintext", "prompt"] = type
529530

530531
def __repr__(self) -> str:
531532
return f"ArgInner(key={self.key!r}, type={self.type!r})"
@@ -546,6 +547,11 @@ def ArgPlainText(key: Optional[str] = None) -> str:
546547
return ArgInner(key, "plaintext") # type: ignore
547548

548549

550+
def ArgPromptResult(key: Optional[str] = None) -> Any:
551+
"""`arg` prompt 发送结果"""
552+
return ArgInner(key, "prompt")
553+
554+
549555
class ArgParam(Param):
550556
"""Arg 注入参数
551557
@@ -559,7 +565,7 @@ def __init__(
559565
self,
560566
*args,
561567
key: str,
562-
type: Literal["message", "str", "plaintext"],
568+
type: Literal["message", "str", "plaintext", "prompt"],
563569
**kwargs: Any,
564570
) -> None:
565571
super().__init__(*args, **kwargs)
@@ -584,15 +590,32 @@ def _check_param(
584590
async def _solve( # pyright: ignore[reportIncompatibleMethodOverride]
585591
self, matcher: "Matcher", **kwargs: Any
586592
) -> Any:
587-
message = matcher.get_arg(self.key)
588-
if message is None:
589-
return message
590593
if self.type == "message":
591-
return message
594+
return self._solve_message(matcher)
592595
elif self.type == "str":
593-
return str(message)
596+
return self._solve_str(matcher)
597+
elif self.type == "plaintext":
598+
return self._solve_plaintext(matcher)
599+
elif self.type == "prompt":
600+
return self._solve_prompt(matcher)
594601
else:
595-
return message.extract_plain_text()
602+
raise ValueError(f"Unknown Arg type: {self.type}")
603+
604+
def _solve_message(self, matcher: "Matcher") -> Optional["Message"]:
605+
return matcher.get_arg(self.key)
606+
607+
def _solve_str(self, matcher: "Matcher") -> Optional[str]:
608+
message = matcher.get_arg(self.key)
609+
return str(message) if message is not None else None
610+
611+
def _solve_plaintext(self, matcher: "Matcher") -> Optional[str]:
612+
message = matcher.get_arg(self.key)
613+
return message.extract_plain_text() if message is not None else None
614+
615+
def _solve_prompt(self, matcher: "Matcher") -> Optional[Any]:
616+
return matcher.state.get(
617+
REJECT_PROMPT_RESULT_KEY.format(key=ARG_KEY.format(key=self.key))
618+
)
596619

597620

598621
class ExceptionParam(Param):

nonebot/params.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,20 @@
1919
ENDSWITH_KEY,
2020
FULLMATCH_KEY,
2121
KEYWORD_KEY,
22+
PAUSE_PROMPT_RESULT_KEY,
2223
PREFIX_KEY,
2324
RAW_CMD_KEY,
25+
RECEIVE_KEY,
2426
REGEX_MATCHED,
27+
REJECT_PROMPT_RESULT_KEY,
2528
SHELL_ARGS,
2629
SHELL_ARGV,
2730
STARTSWITH_KEY,
2831
)
2932
from nonebot.internal.params import Arg as Arg
3033
from nonebot.internal.params import ArgParam as ArgParam
3134
from nonebot.internal.params import ArgPlainText as ArgPlainText
35+
from nonebot.internal.params import ArgPromptResult as ArgPromptResult
3236
from nonebot.internal.params import ArgStr as ArgStr
3337
from nonebot.internal.params import BotParam as BotParam
3438
from nonebot.internal.params import DefaultParam as DefaultParam
@@ -252,6 +256,26 @@ def _last_received(matcher: "Matcher") -> Any:
252256
return Depends(_last_received, use_cache=False)
253257

254258

259+
def ReceivePromptResult(id: Optional[str] = None) -> Any:
260+
"""`receive` prompt 发送结果"""
261+
262+
def _receive_prompt_result(matcher: "Matcher") -> Any:
263+
return matcher.state.get(
264+
REJECT_PROMPT_RESULT_KEY.format(key=RECEIVE_KEY.format(id=id))
265+
)
266+
267+
return Depends(_receive_prompt_result, use_cache=False)
268+
269+
270+
def PausePromptResult() -> Any:
271+
"""`pause` prompt 发送结果"""
272+
273+
def _pause_prompt_result(matcher: "Matcher") -> Any:
274+
return matcher.state.get(PAUSE_PROMPT_RESULT_KEY)
275+
276+
return Depends(_pause_prompt_result, use_cache=False)
277+
278+
255279
__autodoc__ = {
256280
"Arg": True,
257281
"ArgStr": True,
@@ -265,4 +289,5 @@ def _last_received(matcher: "Matcher") -> Any:
265289
"DefaultParam": True,
266290
"MatcherParam": True,
267291
"ExceptionParam": True,
292+
"ArgPromptResult": True,
268293
}

tests/plugins/param/param_arg.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from typing import Annotated
1+
from typing import Annotated, Any
22

33
from nonebot.adapters import Message
4-
from nonebot.params import Arg, ArgPlainText, ArgStr
4+
from nonebot.params import Arg, ArgPlainText, ArgPromptResult, ArgStr
55

66

77
async def arg(key: Message = Arg()) -> Message:
@@ -28,6 +28,10 @@ async def annotated_arg_plain_text(key: Annotated[str, ArgPlainText()]) -> str:
2828
return key
2929

3030

31+
async def annotated_arg_prompt_result(key: Annotated[Any, ArgPromptResult()]) -> Any:
32+
return key
33+
34+
3135
# test dependency priority
3236
async def annotated_prior_arg(key: Annotated[str, ArgStr("foo")] = ArgPlainText()):
3337
return key

tests/plugins/param/param_matcher.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
1-
from typing import TypeVar, Union
1+
from typing import Any, TypeVar, Union
22

33
from nonebot.adapters import Event
44
from nonebot.matcher import Matcher
5-
from nonebot.params import LastReceived, Received
5+
from nonebot.params import (
6+
LastReceived,
7+
PausePromptResult,
8+
Received,
9+
ReceivePromptResult,
10+
)
611

712

813
async def matcher(m: Matcher) -> Matcher:
@@ -59,3 +64,11 @@ async def receive(e: Event = Received("test")) -> Event:
5964

6065
async def last_receive(e: Event = LastReceived()) -> Event:
6166
return e
67+
68+
69+
async def receive_prompt_result(result: Any = ReceivePromptResult("test")) -> Any:
70+
return result
71+
72+
73+
async def pause_prompt_result(result: Any = PausePromptResult()) -> Any:
74+
return result

0 commit comments

Comments
 (0)