22
33from abc import ABCMeta , abstractmethod
44import asyncio
5+ from contextlib import AsyncExitStack
56from dataclasses import dataclass
67import functools
78import importlib as imp
9+ import inspect
810import re
9- from typing import TYPE_CHECKING , Any , ClassVar , Generic , Literal , TypeVar , Union
11+ from typing import TYPE_CHECKING , Any , ClassVar , Generic , Literal , TypeVar , Union , final , overload
1012from weakref import finalize
1113
1214from arclet .alconna import Alconna , Arparma
1315from nonebot import get_plugin_config
1416from nonebot .adapters import Bot , Event , Message
1517from nonebot .compat import PydanticUndefined
16- from nonebot .typing import T_State
18+ from nonebot .dependencies import Dependent , Param
19+ from nonebot .internal .params import DependencyCache , DependParam , DependsInner
20+ from nonebot .typing import T_State , _DependentCallable
21+ from pydantic .fields import FieldInfo
1722from tarina import LRU , lang
1823
1924from .config import Config
2025from .uniseg import UniMessage , get_message_id
2126
2227OutputType = Literal ["help" , "shortcut" , "completion" , "error" ]
28+ T = TypeVar ("T" )
2329TM = TypeVar ("TM" , bound = Union [str , Message , UniMessage ])
2430TE = TypeVar ("TE" , bound = Event )
2531
@@ -57,6 +63,8 @@ def __init_subclass__(cls, **kwargs):
5763 "catch" : cls .catch != Extension .catch and cls .before_catch != Extension .before_catch ,
5864 }
5965
66+ executor : ExtensionExecutor
67+
6068 @property
6169 @abstractmethod
6270 def priority (self ) -> int :
@@ -77,6 +85,40 @@ def namespace(self) -> str:
7785 def validate (self , bot : Bot , event : Event ) -> bool :
7886 return event .get_type () == "message"
7987
88+ @overload
89+ async def inject (
90+ self , dependent : Dependent [T ], * , use_cache : bool = True , validate : bool | FieldInfo = False
91+ ) -> T : ...
92+
93+ @overload
94+ async def inject (self , dependent : tuple [str , type [T ]]) -> T : ...
95+
96+ @overload
97+ async def inject (self , dependent : Any ) -> Any : ...
98+
99+ @final
100+ async def inject (self , dependent : Any , use_cache : bool = True , validate : bool | FieldInfo = False ) -> Any :
101+ # assert isinstance(dependent, (Dependent, DependsInner)), "仅支持 Dependent 或 DependsInner 类型的依赖注入"
102+ if isinstance (dependent , DependsInner ):
103+ if not dependent .dependency :
104+ raise ValueError ("DependsInner 未绑定任何依赖" )
105+ use_cache = dependent .use_cache
106+ validate = dependent .validate
107+ dependent = Dependent .parse (call = dependent .dependency , allow_types = self .executor .params )
108+ param = DependParam (dependent = dependent , use_cache = use_cache , validate = validate )
109+ elif isinstance (dependent , Dependent ):
110+ param = DependParam (dependent = dependent , use_cache = use_cache , validate = validate )
111+ else :
112+ for allow_type in self .executor .params :
113+ if param := allow_type ._check_param (
114+ inspect .Parameter (dependent [0 ], inspect .Parameter .POSITIONAL_OR_KEYWORD , annotation = dependent [1 ]),
115+ self .executor .params ,
116+ ):
117+ break
118+ else :
119+ raise ValueError (f"Unknown parameter { dependent [0 ]} with type { dependent [1 ]} " )
120+ return await self .executor ._dependent_executor (param )
121+
80122 async def output_converter (self , output_type : OutputType , content : str ) -> UniMessage :
81123 """依据输出信息的类型,将字符串转换为消息对象以便发送。"""
82124 return UniMessage (content )
@@ -224,16 +266,39 @@ async def send_wrapper(self, bot: Bot, event: Event, send: TM) -> TM:
224266 return res
225267
226268
269+ class _DependentExecutor :
270+ def __init__ (
271+ self ,
272+ bot : Bot ,
273+ event : Event ,
274+ state : T_State ,
275+ stack : AsyncExitStack | None = None ,
276+ dependency_cache : dict [_DependentCallable [Any ], DependencyCache ] | None = None ,
277+ ):
278+ self .bot = bot
279+ self .event = event
280+ self .state = state
281+ self .stack = stack
282+ self .dependency_cache = dependency_cache or {}
283+
284+ async def __call__ (self , param : Param ):
285+ return await param ._solve (
286+ stack = self .stack , dependency_cache = self .dependency_cache , bot = self .bot , event = self .event , state = self .state
287+ )
288+
289+
227290class ExtensionExecutor (SelectedExtensions ):
228291 globals : ClassVar [list [type [Extension ] | Extension ]] = [DefaultExtension ()]
229292 _rule : AlconnaRule
293+ _dependent_executor : _DependentExecutor
230294
231295 def __init__ (
232296 self ,
233297 rule : AlconnaRule ,
234298 extensions : list [type [Extension ] | Extension ] | None = None ,
235299 excludes : list [str | type [Extension ]] | None = None ,
236300 ):
301+ self .params : tuple [type [Param ], ...] = ()
237302 self .extensions : list [Extension ] = []
238303 for ext in self .globals :
239304 if isinstance (ext , type ):
@@ -284,6 +349,8 @@ def _callback(self, *append_global_ext: type[Extension] | Extension):
284349 def select (self , bot : Bot , event : Event ) -> SelectedExtensions :
285350 context = [ext for ext in self .extensions if ext .validate (bot , event )]
286351 context .sort (key = lambda ext : ext .priority )
352+ for ext in context :
353+ ext .executor = self
287354 return SelectedExtensions (context )
288355
289356 def before_catch (self , name : str , annotation : Any , default : Any ) -> bool :
0 commit comments