|
| 1 | +from contextlib import AsyncExitStack |
1 | 2 | import inspect |
2 | 3 | from typing import Annotated, Any, ClassVar, Literal, Optional, TypeVar, Union, overload |
3 | | -from typing_extensions import get_args |
| 4 | +from typing_extensions import Self, get_args, override |
4 | 5 |
|
5 | 6 | from arclet.alconna import Alconna, Arparma, Duplication, Empty |
6 | 7 | from arclet.alconna.builtin import generate_duplication |
|
10 | 11 | from nonebot.internal.adapter import Bot, Event |
11 | 12 | from nonebot.internal.matcher import Matcher |
12 | 13 | from nonebot.internal.params import Depends |
13 | | -from nonebot.typing import T_State |
| 14 | +from nonebot.typing import T_DependencyCache, T_State |
| 15 | +from nonebot.utils import generic_check_issubclass |
14 | 16 | from tarina import run_always_await |
15 | 17 | from tarina.generic import get_origin |
16 | 18 |
|
@@ -349,3 +351,47 @@ async def __call__(self, _state: T_State, event: Event, bot: Bot) -> bool: |
349 | 351 | self.result = None |
350 | 352 | return True |
351 | 353 | return False |
| 354 | + |
| 355 | + |
| 356 | +class StackParam(Param): |
| 357 | + """上下文栈注入参数。 |
| 358 | +
|
| 359 | + 本注入解析 AsyncExitStack 实例,用于在依赖注入中管理异步上下文。 |
| 360 | + """ |
| 361 | + |
| 362 | + def __repr__(self) -> str: |
| 363 | + return "_StackParam()" |
| 364 | + |
| 365 | + @classmethod |
| 366 | + @override |
| 367 | + def _check_param(cls, param: inspect.Parameter, allow_types: tuple[type[Param], ...]) -> Optional[Self]: |
| 368 | + if param.annotation == AsyncExitStack: |
| 369 | + return cls(..., type=AsyncExitStack) |
| 370 | + if generic_check_issubclass(param.annotation, AsyncExitStack): |
| 371 | + return cls(..., type=AsyncExitStack, default=None) |
| 372 | + |
| 373 | + @override |
| 374 | + async def _solve(self, stack: Optional[AsyncExitStack] = None, **kwargs: Any) -> Any: |
| 375 | + return stack |
| 376 | + |
| 377 | + |
| 378 | +class DependencyCacheParam(Param): |
| 379 | + """依赖缓存注入参数。 |
| 380 | +
|
| 381 | + 本注入解析 T_DependencyCache 实例,用于在依赖注入中管理依赖缓存。 |
| 382 | + """ |
| 383 | + |
| 384 | + def __repr__(self) -> str: |
| 385 | + return "_DependencyCacheParam()" |
| 386 | + |
| 387 | + @classmethod |
| 388 | + @override |
| 389 | + def _check_param(cls, param: inspect.Parameter, allow_types: tuple[type[Param], ...]) -> Optional[Self]: |
| 390 | + if param.annotation == T_DependencyCache: |
| 391 | + return cls(..., type=T_DependencyCache) |
| 392 | + if generic_check_issubclass(param.annotation, T_DependencyCache): |
| 393 | + return cls(..., type=T_DependencyCache, default=None) |
| 394 | + |
| 395 | + @override |
| 396 | + async def _solve(self, dependency_cache: Optional[T_DependencyCache] = None, **kwargs: Any) -> Any: |
| 397 | + return dependency_cache |
0 commit comments