Skip to content

Commit e793130

Browse files
committed
✅ Add tests for OneBotImplementation
1 parent fe56628 commit e793130

File tree

11 files changed

+107
-28
lines changed

11 files changed

+107
-28
lines changed

nonebot_plugin_all4one/middlewares/__init__.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
from uuid import uuid4
33
from datetime import datetime
44
from functools import partial
5-
from abc import ABC, abstractmethod
65
from asyncio import Queue as BaseQueue
6+
from abc import ABC, ABCMeta, abstractmethod
77
from typing import (
88
TYPE_CHECKING,
99
Any,
@@ -88,8 +88,21 @@ async def get(self):
8888
return event
8989

9090

91-
class Middleware(ABC):
92-
supported_actions: ClassVar[Set[str]] = set()
91+
class _MiddlewareMeta(type):
92+
def __new__(cls, name, bases, attrs):
93+
supported_actions = set()
94+
for base in bases:
95+
supported_actions.update(base.supported_actions)
96+
attrs["supported_actions"] = supported_actions
97+
return type.__new__(cls, name, bases, attrs)
98+
99+
100+
class MiddlewareMeta(_MiddlewareMeta, ABCMeta):
101+
pass
102+
103+
104+
class Middleware(metaclass=MiddlewareMeta):
105+
supported_actions: ClassVar[Set[str]]
93106

94107
def __init__(self, bot: Bot):
95108
self.bot = bot

nonebot_plugin_all4one/onebotimpl/__init__.py

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,18 @@
55
from datetime import datetime
66
from functools import partial
77
from contextlib import asynccontextmanager
8-
from typing import Any, Dict, List, Type, Union, Literal, Optional, AsyncGenerator, cast
8+
from typing import (
9+
Any,
10+
Set,
11+
Dict,
12+
List,
13+
Type,
14+
Union,
15+
Literal,
16+
Optional,
17+
AsyncGenerator,
18+
cast,
19+
)
920

1021
import msgpack
1122
from nonebot import Driver
@@ -311,7 +322,7 @@ async def _handle_ws(
311322
await t2
312323
t1.cancel()
313324

314-
async def bot_connect(self, bot: Bot) -> None:
325+
def bot_connect(self, bot: Bot) -> None:
315326
if (middleware := self._middlewares.get(bot.type, None)) is None:
316327
return
317328
middleware = middleware(bot)
@@ -444,36 +455,32 @@ async def _websocket_rev(
444455
)
445456
await asyncio.sleep(conn.reconnect_interval)
446457

447-
async def bot_disconnect(self, bot: Bot) -> None:
458+
def bot_disconnect(self, bot: Bot) -> None:
448459
if (middleware := self.middlewares.pop(bot.self_id, None)) is None:
449460
return
450461
for task in middleware.tasks:
451462
if not task.done():
452463
task.cancel()
453-
await asyncio.gather(*middleware.tasks)
464+
465+
def import_middlewares(self, middlewares: Optional[Set[str]] = None):
466+
if middlewares is None:
467+
middlewares = set(self.driver._adapters.keys())
468+
for middleware in middlewares:
469+
try:
470+
if middleware in MIDDLEWARE_MAP:
471+
module = importlib.import_module(
472+
f"nonebot_plugin_all4one.middlewares.{MIDDLEWARE_MAP[middleware]}"
473+
)
474+
self.register_middleware(getattr(module, "Middleware"))
475+
else:
476+
logger.warning(f"Can not find middleware for Adapter {middleware}")
477+
except Exception as e:
478+
logger.warning(f"Can not load middleware for Adapter {middleware}: {e}")
454479

455480
def setup(self):
456481
@self.driver.on_startup
457482
async def _():
458-
adapters = (
459-
self.driver._adapters.keys()
460-
if self.config.middlewares is None
461-
else self.config.middlewares
462-
)
463-
464-
for adapter in adapters:
465-
try:
466-
if adapter in MIDDLEWARE_MAP:
467-
module = importlib.import_module(
468-
f"nonebot_plugin_all4one.middlewares.{MIDDLEWARE_MAP[adapter]}"
469-
)
470-
self.register_middleware(getattr(module, "Middleware"))
471-
else:
472-
logger.warning(f"Can not find middleware for Adapter {adapter}")
473-
except Exception as e:
474-
logger.warning(
475-
f"Can not load middleware for Adapter {adapter}: {e}"
476-
)
483+
self.import_middlewares(self.config.middlewares)
477484

478485
@self.driver.on_shutdown
479486
async def _():
@@ -487,10 +494,10 @@ async def _():
487494
async def _(bot: Bot):
488495
if bot.self_id.startswith("a4o@"):
489496
return
490-
await self.bot_connect(bot)
497+
self.bot_connect(bot)
491498

492499
@self.driver.on_bot_disconnect
493500
async def _(bot: Bot):
494501
if bot.self_id.startswith("a4o@"):
495502
return
496-
await self.bot_disconnect(bot)
503+
self.bot_disconnect(bot)

tests/conftest.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,29 @@ def load_adapters(nonebug_init: None):
2727
driver.register_adapter(TelegramAdapter)
2828
driver.register_adapter(ConsoleAdapter)
2929

30+
nonebot.require("nonebot_plugin_all4one")
31+
from nonebot_plugin_all4one import obimpl
32+
33+
obimpl.import_middlewares()
34+
35+
36+
@pytest.fixture
37+
def FakeMiddleware():
38+
from nonebot_plugin_all4one.middlewares import Middleware
39+
40+
class FakeMiddleware(Middleware):
41+
@classmethod
42+
def get_name(cls):
43+
return "fake"
44+
45+
def get_platform(self):
46+
return "fake"
47+
48+
async def to_onebot_event(self, event):
49+
return []
50+
51+
return FakeMiddleware
52+
3053

3154
@pytest.fixture
3255
async def app(nonebug_init: None, tmp_path: Path):
File renamed without changes.
File renamed without changes.
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from nonebug import App
2+
from nonebot import get_driver
3+
4+
5+
async def test_bot_connect(app: App, FakeMiddleware):
6+
from nonebot_plugin_all4one import obimpl
7+
8+
obimpl.register_middleware(FakeMiddleware)
9+
10+
async with app.test_api() as ctx:
11+
bot = ctx.create_bot()
12+
obimpl.bot_connect(bot)
13+
assert obimpl.middlewares[bot.self_id].bot == bot

0 commit comments

Comments
 (0)