1818from pydantic .fields import FieldInfo as PydanticFieldInfo
1919
2020from nonebot .compat import FieldInfo , ModelField , PydanticUndefined , extract_field_info
21+ from nonebot .consts import ARG_KEY , REJECT_PROMPT_RESULT_KEY
2122from nonebot .dependencies import Dependent , Param
2223from nonebot .dependencies .utils import check_field_type
2324from nonebot .exception import SkippedException
3940)
4041
4142if TYPE_CHECKING :
42- from nonebot .adapters import Bot , Event
43+ from nonebot .adapters import Bot , Event , Message
4344 from nonebot .matcher import Matcher
4445
4546
@@ -522,10 +523,10 @@ async def _check( # pyright: ignore[reportIncompatibleMethodOverride]
522523
523524class ArgInner :
524525 def __init__ (
525- self , key : Optional [str ], type : Literal ["message" , "str" , "plaintext" ]
526+ self , key : Optional [str ], type : Literal ["message" , "str" , "plaintext" , "prompt" ]
526527 ) -> None :
527528 self .key : Optional [str ] = key
528- self .type : Literal ["message" , "str" , "plaintext" ] = type
529+ self .type : Literal ["message" , "str" , "plaintext" , "prompt" ] = type
529530
530531 def __repr__ (self ) -> str :
531532 return f"ArgInner(key={ self .key !r} , type={ self .type !r} )"
@@ -546,6 +547,11 @@ def ArgPlainText(key: Optional[str] = None) -> str:
546547 return ArgInner (key , "plaintext" ) # type: ignore
547548
548549
550+ def ArgPromptResult (key : Optional [str ] = None ) -> Any :
551+ """`arg` prompt 发送结果"""
552+ return ArgInner (key , "prompt" )
553+
554+
549555class ArgParam (Param ):
550556 """Arg 注入参数
551557
@@ -559,7 +565,7 @@ def __init__(
559565 self ,
560566 * args ,
561567 key : str ,
562- type : Literal ["message" , "str" , "plaintext" ],
568+ type : Literal ["message" , "str" , "plaintext" , "prompt" ],
563569 ** kwargs : Any ,
564570 ) -> None :
565571 super ().__init__ (* args , ** kwargs )
@@ -584,15 +590,32 @@ def _check_param(
584590 async def _solve ( # pyright: ignore[reportIncompatibleMethodOverride]
585591 self , matcher : "Matcher" , ** kwargs : Any
586592 ) -> Any :
587- message = matcher .get_arg (self .key )
588- if message is None :
589- return message
590593 if self .type == "message" :
591- return message
594+ return self . _solve_message ( matcher )
592595 elif self .type == "str" :
593- return str (message )
596+ return self ._solve_str (matcher )
597+ elif self .type == "plaintext" :
598+ return self ._solve_plaintext (matcher )
599+ elif self .type == "prompt" :
600+ return self ._solve_prompt (matcher )
594601 else :
595- return message .extract_plain_text ()
602+ raise ValueError (f"Unknown Arg type: { self .type } " )
603+
604+ def _solve_message (self , matcher : "Matcher" ) -> Optional ["Message" ]:
605+ return matcher .get_arg (self .key )
606+
607+ def _solve_str (self , matcher : "Matcher" ) -> Optional [str ]:
608+ message = matcher .get_arg (self .key )
609+ return str (message ) if message is not None else None
610+
611+ def _solve_plaintext (self , matcher : "Matcher" ) -> Optional [str ]:
612+ message = matcher .get_arg (self .key )
613+ return message .extract_plain_text () if message is not None else None
614+
615+ def _solve_prompt (self , matcher : "Matcher" ) -> Optional [Any ]:
616+ return matcher .state .get (
617+ REJECT_PROMPT_RESULT_KEY .format (key = ARG_KEY .format (key = self .key ))
618+ )
596619
597620
598621class ExceptionParam (Param ):
0 commit comments