|
1 | 1 | import asyncio |
2 | 2 | from uuid import uuid4 |
3 | 3 | from datetime import datetime |
4 | | -from functools import partial |
| 4 | +from abc import ABC, abstractmethod |
5 | 5 | from asyncio import Queue as BaseQueue |
6 | | -from abc import ABC, ABCMeta, abstractmethod |
7 | 6 | from typing import ( |
8 | 7 | TYPE_CHECKING, |
9 | 8 | Any, |
10 | | - Set, |
11 | 9 | Dict, |
12 | 10 | List, |
13 | | - Type, |
14 | 11 | Union, |
15 | 12 | Generic, |
16 | 13 | Literal, |
17 | 14 | TypeVar, |
18 | | - ClassVar, |
19 | 15 | Optional, |
20 | 16 | ) |
21 | 17 |
|
|
44 | 40 | } |
45 | 41 |
|
46 | 42 |
|
47 | | -class supported_action: |
48 | | - def __init__(self, fn): |
49 | | - self.fn = fn |
50 | | - |
51 | | - def __set_name__(self, owner: Type["Middleware"], name: str): |
52 | | - owner.supported_actions.add(name) |
53 | | - |
54 | | - def __call__(self, *args, **kwargs): |
55 | | - return self.fn(*args, **kwargs) |
56 | | - |
57 | | - def __get__(self, obj, objtype=None): |
58 | | - return partial(self.__call__, obj) |
59 | | - |
60 | | - |
61 | 43 | _T = TypeVar("_T", bound=OneBotEvent) |
62 | 44 | if TYPE_CHECKING: |
63 | 45 |
|
@@ -88,37 +70,39 @@ async def get(self): |
88 | 70 | return event |
89 | 71 |
|
90 | 72 |
|
91 | | -class _MiddlewareMeta(type): |
92 | | - def __new__(cls, name, bases, attrs): |
93 | | - supported_actions = set() |
94 | | - for base in bases: |
95 | | - supported_actions.update(base.supported_actions) |
96 | | - attrs["supported_actions"] = supported_actions |
97 | | - return type.__new__(cls, name, bases, attrs) |
98 | | - |
99 | | - |
100 | | -class MiddlewareMeta(_MiddlewareMeta, ABCMeta): |
101 | | - pass |
102 | | - |
| 73 | +def supported_action(func): |
| 74 | + """标记支持的动作""" |
| 75 | + func.__supported__ = True |
| 76 | + return func |
103 | 77 |
|
104 | | -class Middleware(metaclass=MiddlewareMeta): |
105 | | - supported_actions: ClassVar[Set[str]] |
106 | 78 |
|
| 79 | +class Middleware(ABC): |
107 | 80 | def __init__(self, bot: Bot): |
108 | 81 | self.bot = bot |
109 | 82 | self.tasks: List[asyncio.Task] = [] |
110 | 83 | self.queues: List[Queue[OneBotEvent]] = [] |
| 84 | + self._supported_actions = self._get_supported_actions() |
111 | 85 |
|
| 86 | + def _get_supported_actions(self) -> List[str]: |
| 87 | + """获取支持的动作列表""" |
| 88 | + supported_actions = set() |
| 89 | + for class_ in self.__class__.__mro__: |
| 90 | + for name, attr in class_.__dict__.items(): |
| 91 | + if not name.startswith("_") and getattr(attr, "__supported__", False): |
| 92 | + supported_actions.add(name) |
| 93 | + return list(supported_actions) |
| 94 | + |
| 95 | + @supported_action |
112 | 96 | async def get_supported_actions(self, **kwargs: Any) -> List[str]: |
113 | 97 | """获取支持的动作列表 |
114 | 98 |
|
115 | 99 | 参数: |
116 | 100 | kwargs: 扩展字段 |
117 | 101 | """ |
118 | | - return list(self.supported_actions) |
| 102 | + return self._supported_actions |
119 | 103 |
|
120 | 104 | async def _call_api(self, api: str, **kwargs: Any) -> Any: |
121 | | - if api not in await self.get_supported_actions(): |
| 105 | + if api not in self._supported_actions: |
122 | 106 | raise UnsupportedAction( |
123 | 107 | status="failed", |
124 | 108 | retcode=10002, |
|
0 commit comments