Skip to content

Commit 3532a01

Browse files
authored
♻️ 优化动作的类型标注 (#13)
1 parent ae2f734 commit 3532a01

File tree

2 files changed

+24
-36
lines changed

2 files changed

+24
-36
lines changed

nonebot_plugin_all4one/middlewares/__init__.py

Lines changed: 19 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,17 @@
11
import asyncio
22
from uuid import uuid4
33
from datetime import datetime
4-
from functools import partial
4+
from abc import ABC, abstractmethod
55
from asyncio import Queue as BaseQueue
6-
from abc import ABC, ABCMeta, abstractmethod
76
from typing import (
87
TYPE_CHECKING,
98
Any,
10-
Set,
119
Dict,
1210
List,
13-
Type,
1411
Union,
1512
Generic,
1613
Literal,
1714
TypeVar,
18-
ClassVar,
1915
Optional,
2016
)
2117

@@ -44,20 +40,6 @@
4440
}
4541

4642

47-
class supported_action:
48-
def __init__(self, fn):
49-
self.fn = fn
50-
51-
def __set_name__(self, owner: Type["Middleware"], name: str):
52-
owner.supported_actions.add(name)
53-
54-
def __call__(self, *args, **kwargs):
55-
return self.fn(*args, **kwargs)
56-
57-
def __get__(self, obj, objtype=None):
58-
return partial(self.__call__, obj)
59-
60-
6143
_T = TypeVar("_T", bound=OneBotEvent)
6244
if TYPE_CHECKING:
6345

@@ -88,37 +70,39 @@ async def get(self):
8870
return event
8971

9072

91-
class _MiddlewareMeta(type):
92-
def __new__(cls, name, bases, attrs):
93-
supported_actions = set()
94-
for base in bases:
95-
supported_actions.update(base.supported_actions)
96-
attrs["supported_actions"] = supported_actions
97-
return type.__new__(cls, name, bases, attrs)
98-
99-
100-
class MiddlewareMeta(_MiddlewareMeta, ABCMeta):
101-
pass
102-
73+
def supported_action(func):
74+
"""标记支持的动作"""
75+
func.__supported__ = True
76+
return func
10377

104-
class Middleware(metaclass=MiddlewareMeta):
105-
supported_actions: ClassVar[Set[str]]
10678

79+
class Middleware(ABC):
10780
def __init__(self, bot: Bot):
10881
self.bot = bot
10982
self.tasks: List[asyncio.Task] = []
11083
self.queues: List[Queue[OneBotEvent]] = []
84+
self._supported_actions = self._get_supported_actions()
11185

86+
def _get_supported_actions(self) -> List[str]:
87+
"""获取支持的动作列表"""
88+
supported_actions = set()
89+
for class_ in self.__class__.__mro__:
90+
for name, attr in class_.__dict__.items():
91+
if not name.startswith("_") and getattr(attr, "__supported__", False):
92+
supported_actions.add(name)
93+
return list(supported_actions)
94+
95+
@supported_action
11296
async def get_supported_actions(self, **kwargs: Any) -> List[str]:
11397
"""获取支持的动作列表
11498
11599
参数:
116100
kwargs: 扩展字段
117101
"""
118-
return list(self.supported_actions)
102+
return self._supported_actions
119103

120104
async def _call_api(self, api: str, **kwargs: Any) -> Any:
121-
if api not in await self.get_supported_actions():
105+
if api not in self._supported_actions:
122106
raise UnsupportedAction(
123107
status="failed",
124108
retcode=10002,

tests/onebotimpl/test_call_api.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,8 @@ async def test_get_supported_action(app: App, FakeMiddleware):
2020
bot = ctx.create_bot()
2121
middleware = FakeMiddleware(bot)
2222
supported_actions = await obimpl.get_supported_actions(middleware)
23-
assert set(supported_actions) == {"upload_file", "get_file"}
23+
assert set(supported_actions) == {
24+
"upload_file",
25+
"get_file",
26+
"get_supported_actions",
27+
}

0 commit comments

Comments
 (0)