Skip to content

Commit 0f67b92

Browse files
committed
✨ use msg_id to cache generated unimsg
1 parent 53b105d commit 0f67b92

File tree

13 files changed

+113
-60
lines changed

13 files changed

+113
-60
lines changed

example/plugins/demo.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Union, Literal
22

33
from nonebot import require
4+
from nonebot.internal.adapter import Event
45
from nonebot.adapters.onebot.v12 import Bot
56
from importlib_metadata import distributions
67
from nonebot.adapters.onebot.v12.event import GroupMessageDeleteEvent
@@ -516,3 +517,43 @@ async def setu_h(count: int, tags: Match[tuple[str, ...]], r18: bool):
516517
@preview.handle()
517518
async def preview_h(content: UniMessage):
518519
await preview.finish("rendering preview:\n" + content)
520+
521+
522+
from arclet.alconna.tools.debug import analyse_header
523+
524+
525+
class TestExtension1(Extension):
526+
@property
527+
def priority(self) -> int:
528+
return 12
529+
530+
@property
531+
def id(self) -> str:
532+
return "test1"
533+
534+
def post_init(self, alc: Alconna) -> None:
535+
def test_func(msg: str):
536+
test_res = analyse_header(alc.prefixes, alc.command, msg, raise_exception=False)
537+
return test_res and test_res.matched
538+
539+
self.test_func = test_func
540+
541+
async def call_llm(self, msg: str):
542+
return "calculate " + "".join(c for c in msg if ord(c) < 256)
543+
544+
async def receive_wrapper(self, bot: Bot, event: Event, command: Alconna, receive: UniMessage) -> UniMessage:
545+
msg = receive.extract_plain_text()
546+
if self.test_func(msg):
547+
return receive
548+
output = await self.call_llm(msg)
549+
return UniMessage(output)
550+
551+
552+
calculate = Command("calculate <expression>", "计算").build(
553+
auto_send_output=True, extensions=[TestExtension1()], use_cmd_start=False
554+
)
555+
556+
557+
@calculate.handle()
558+
async def calculate_h(expression: str):
559+
await calculate.finish(expression)

src/nonebot_plugin_alconna/builtins/extensions/discord.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Optional
22

3-
from tarina import lang
3+
from tarina import LRU, lang
44
from arclet.alconna import Alconna
55
from nonebot.adapters import Event
66
from nonebot.typing import T_State
@@ -71,6 +71,7 @@ def __init__(
7171
self.nsfw = nsfw
7272
super().__init__()
7373
self.using = False
74+
self.cache: "LRU[int, UniMessage]" = LRU(20) # noqa: UP037
7475

7576
def post_init(self, alc: Alconna) -> None:
7677
if "/" not in alc.prefixes or (
@@ -123,6 +124,8 @@ def validate(self, bot, event: Event) -> bool:
123124
async def message_provider(self, event: Event, state: T_State, bot, use_origin: bool = False):
124125
if not isinstance(event, ApplicationCommandInteractionEvent):
125126
return None
127+
if event.id in self.cache:
128+
return self.cache[event.id]
126129
data = event.data
127130
cmd = f"/{data.name}"
128131

@@ -142,7 +145,9 @@ def _handle_options(options: list[ApplicationCommandInteractionDataOption]):
142145
cmd += " "
143146
cmd += " ".join(_handle_options(data.options))
144147

145-
return UniMessage(cmd.rstrip())
148+
res = UniMessage(cmd.rstrip())
149+
self.cache[event.id] = res
150+
return res
146151

147152
@classmethod
148153
async def send_deferred_response(cls) -> None:

src/nonebot_plugin_alconna/builtins/extensions/reply.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def __init__(self, add_left: bool = False, sep: str = " "):
7676
self.add_left = add_left
7777
self.sep = sep
7878

79-
cache: "LRU[int, UniMessage]" = LRU(20)
79+
cache: "LRU[str, UniMessage]" = LRU(20)
8080

8181
@property
8282
def priority(self) -> int:
@@ -87,17 +87,17 @@ def id(self) -> str:
8787
return "builtins.extensions.reply:ReplyMergeExtension"
8888

8989
async def message_provider(self, event, state, bot, use_origin: bool = False):
90-
event_id = id(event)
91-
if event_id in self.cache:
92-
return self.cache[event_id]
9390
if event.get_type() != "message":
9491
return None
9592
try:
9693
msg = event.get_message()
9794
except (NotImplementedError, ValueError):
9895
return None
96+
msg_id = UniMessage.get_message_id(event, bot)
97+
if msg_id in self.cache:
98+
return self.cache[msg_id]
9999
uni_msg = UniMessage.generate_sync(message=msg, bot=bot)
100-
self.cache[event_id] = uni_msg
100+
self.cache[msg_id] = uni_msg
101101
if not (reply := await reply_fetch(event, bot)):
102102
return uni_msg
103103
if not reply.msg:
@@ -109,11 +109,11 @@ async def message_provider(self, event, state, bot, use_origin: bool = False):
109109
if self.add_left:
110110
uni_msg_reply += self.sep
111111
uni_msg_reply.extend(uni_msg)
112-
self.cache[event_id] = uni_msg_reply
112+
self.cache[msg_id] = uni_msg_reply
113113
return uni_msg_reply
114114
uni_msg += self.sep
115115
uni_msg.extend(uni_msg_reply)
116-
self.cache[event_id] = uni_msg
116+
self.cache[msg_id] = uni_msg
117117
return uni_msg
118118

119119

src/nonebot_plugin_alconna/builtins/plugins/with/extension.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,21 +23,21 @@ def id(self) -> str:
2323
prefixes: list[str]
2424
command: str
2525
sep: str
26-
cache: "LRU[int, UniMessage]" = LRU(20)
26+
cache: "LRU[str, UniMessage]" = LRU(20)
2727

2828
def post_init(self, alc: Alconna) -> None:
2929
self.prefixes = [pf for pf in alc.prefixes if isinstance(pf, str)]
3030
self.command = alc.header_display
3131
self.sep = alc.separators[0]
3232

3333
async def receive_wrapper(self, bot: Bot, event: Event, command: Alconna, receive: UniMessage) -> UniMessage:
34-
event_id = id(event)
35-
if event_id in self.cache:
36-
return self.cache[event_id]
34+
msg_id = UniMessage.get_message_id(event, bot)
35+
if msg_id in self.cache:
36+
return self.cache[msg_id]
3737
target = UniMessage.get_target(event, bot)
3838
prefix = self.supplier(target)
3939
if not prefix or not command.header_display.endswith(prefix):
4040
return receive
4141
res = UniMessage.text(random.choice(self.prefixes) + prefix + self.sep) + receive
42-
self.cache[event_id] = res
42+
self.cache[msg_id] = res
4343
return res

src/nonebot_plugin_alconna/extension.py

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def id(self) -> str:
125125

126126
_callbacks = set()
127127

128-
unimsg_cache: LRU[int, UniMessage] = LRU(16)
128+
unimsg_cache: LRU[str, UniMessage] = LRU(16)
129129

130130

131131
class ExtensionExecutor:
@@ -208,24 +208,17 @@ async def output_converter(self, output_type: OutputType, content: str) -> UniMe
208208
async def message_provider(
209209
self, event: Event, state: T_State, bot: Bot, use_origin: bool = False
210210
) -> UniMessage | None:
211-
event_id = id(event)
212-
if (uni_msg := unimsg_cache.get(event_id)) is not None:
213-
msg = uni_msg
214-
elif event.get_type() != "message":
215-
msg = None
216-
else:
217-
try:
218-
msg = event.get_message()
219-
except (NotImplementedError, ValueError):
220-
msg = None
211+
if event.get_type().startswith("message"):
212+
msg = event.get_message()
221213
if use_origin:
222-
try:
223-
msg = getattr(event, "original_message", msg) # type: ignore
224-
except (NotImplementedError, ValueError):
225-
pass
226-
if msg is not None:
214+
msg = getattr(event, "original_message", None) or msg # type: ignore
215+
msg_id = UniMessage.get_message_id(event, bot)
216+
if (uni_msg := unimsg_cache.get(msg_id)) is not None:
217+
msg = uni_msg
218+
else:
227219
msg = UniMessage.generate_without_reply(message=msg, bot=bot)
228-
unimsg_cache[event_id] = msg
220+
unimsg_cache[msg_id] = msg
221+
return msg
229222
exc = None
230223
for ext in self.context:
231224
if not ext._overrides["message_provider"]:
@@ -238,7 +231,7 @@ async def message_provider(
238231
if exc is not None:
239232
raise exc
240233

241-
return msg
234+
return None
242235

243236
async def receive_wrapper(self, bot: Bot, event: Event, command: Alconna, receive: UniMessage) -> UniMessage:
244237
res = receive

src/nonebot_plugin_alconna/rule.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import weakref
33
from typing import Any, Union, Literal, Optional
44

5+
import nonebot
56
from nonebot.typing import T_State
67
from tarina import lang, init_spec
78
from nonebot.matcher import Matcher
@@ -13,9 +14,6 @@
1314
from arclet.alconna.exceptions import SpecialOptionTriggered
1415
from arclet.alconna import Alconna, Arparma, CompSession, output_manager, command_manager
1516

16-
require("nonebot_plugin_waiter")
17-
from nonebot_plugin_waiter import waiter
18-
1917
from .i18n import Lang
2018
from .config import Config
2119
from .uniseg import UniMsg, UniMessage
@@ -24,7 +22,14 @@
2422
from .extension import Extension, ExtensionExecutor
2523
from .consts import ALCONNA_RESULT, ALCONNA_EXTENSION, ALCONNA_EXEC_RESULT, log
2624

27-
_modules = set()
25+
try:
26+
if nonebot._driver:
27+
require("nonebot_plugin_waiter")
28+
from nonebot_plugin_waiter import waiter
29+
else:
30+
waiter = None
31+
except (RuntimeError, ValueError):
32+
waiter = None
2833

2934

3035
def check_self_send(bot: Bot, event: Event) -> bool:
@@ -232,7 +237,7 @@ async def handle(
232237
def _checker(_event: Event):
233238
return session_id == _event.get_session_id()
234239

235-
w = waiter(["message"], Matcher, keep_session=True, block=False, rule=Rule(_checker))(self._waiter)
240+
w = waiter(["message"], Matcher, keep_session=True, block=False, rule=Rule(_checker))(self._waiter) # type: ignore
236241

237242
while interface.available:
238243

src/nonebot_plugin_alconna/uniseg/adapters/qq/exporter.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,11 @@ async def send_to(self, target: Union[Target, Event], bot: Bot, message: Message
376376
message = message.exclude("mention_channel", "mention_user", "mention_everyone", "reference")
377377
if target.private:
378378
res = await bot.send_to_c2c(
379-
openid=target.id, message=message, msg_id=target.source, msg_seq=target.extra.get("qq.reply_seq"), **kwargs
379+
openid=target.id,
380+
message=message,
381+
msg_id=target.source,
382+
msg_seq=target.extra.get("qq.reply_seq"),
383+
**kwargs,
380384
)
381385
elif target.extra.get("qq.interaction", False):
382386
return await bot.send_to_group(group_openid=target.id, message=message, event_id=target.source, **kwargs)

src/nonebot_plugin_alconna/uniseg/message.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1268,6 +1268,8 @@ def get_message_id(event: Event | None = None, bot: Bot | None = None, adapter:
12681268
event = current_event.get()
12691269
except LookupError as e:
12701270
raise SerializeFailed(lang.require("nbp-uniseg", "event_missing")) from e
1271+
if hasattr(event, "__uniseg_message_id__"):
1272+
return event.__uniseg_message_id__ # type: ignore
12711273
if not adapter:
12721274
if not bot:
12731275
try:
@@ -1277,7 +1279,8 @@ def get_message_id(event: Event | None = None, bot: Bot | None = None, adapter:
12771279
_adapter = bot.adapter
12781280
adapter = _adapter.get_name()
12791281
if fn := alter_get_exporter(adapter):
1280-
return fn.get_message_id(event)
1282+
setattr(event, "__uniseg_message_id__", msg_id := fn.get_message_id(event))
1283+
return msg_id
12811284
raise SerializeFailed(lang.require("nbp-uniseg", "unsupported").format(adapter=adapter))
12821285

12831286
@staticmethod

tests/fake.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,13 @@
1212
from nonebot.adapters.discord.event import GuildMessageCreateEvent as DiscordMessageEvent
1313

1414

15+
_msg_ids = iter(range(1000000))
16+
17+
18+
def get_msg_id() -> int:
19+
return next(_msg_ids)
20+
21+
1522
def fake_group_message_event_v11(**field) -> "GroupMessageEventV11":
1623
from pydantic import create_model
1724
from nonebot.adapters.onebot.v11.event import Sender
@@ -27,7 +34,6 @@ class FakeEvent(_fake):
2734
user_id: int = 10
2835
message_type: Literal["group"] = "group"
2936
group_id: int = 10000
30-
message_id: int = 1
3137
message: Message = Message("test")
3238
raw_message: str = "test"
3339
font: int = 0
@@ -41,7 +47,7 @@ class FakeEvent(_fake):
4147
class Config:
4248
extra = "allow"
4349

44-
return FakeEvent(**field)
50+
return FakeEvent(message_id=get_msg_id(), **field)
4551

4652

4753
def fake_private_message_event_v11(**field) -> "PrivateMessageEventV11":
@@ -58,7 +64,6 @@ class FakeEvent(_fake):
5864
sub_type: str = "friend"
5965
user_id: int = 10
6066
message_type: Literal["private"] = "private"
61-
message_id: int = 1
6267
message: Message = Message("test")
6368
raw_message: str = "test"
6469
font: int = 0
@@ -68,7 +73,7 @@ class FakeEvent(_fake):
6873
class Config:
6974
extra = "forbid"
7075

71-
return FakeEvent(**field)
76+
return FakeEvent(message_id=get_msg_id(), **field)
7277

7378

7479
def fake_discord_interaction_event(**field) -> "ApplicationCommandInteractionEvent":
@@ -77,7 +82,7 @@ def fake_discord_interaction_event(**field) -> "ApplicationCommandInteractionEve
7782

7883
_fake = create_model("_fake", __base__=ApplicationCommandInteractionEvent)
7984
field["type"] = 2
80-
field["id"] = 123456
85+
field["id"] = get_msg_id() + 123456
8186
field["application_id"] = 123456789
8287
field["token"] = "sometoken" # noqa: S105
8388
field["version"] = 1
@@ -95,7 +100,7 @@ def fake_message_event_discord(content: str) -> "DiscordMessageEvent":
95100
return type_validate_python(
96101
GuildMessageCreateEvent,
97102
{
98-
"id": 11234,
103+
"id": get_msg_id() + 11234,
99104
"channel_id": 5566,
100105
"guild_id": 6677,
101106
"author": {
@@ -153,7 +158,7 @@ class Config:
153158
extra = "allow"
154159

155160
_message = field.pop("message", Message("test"))
156-
event = FakeEvent(message={"id": "1", "content": "text"}, **field) # type: ignore
161+
event = FakeEvent(message={"id": str(get_msg_id()), "content": "text"}, **field) # type: ignore
157162
event._message = _message
158163
event.original_message = _message
159164
return event
@@ -168,7 +173,6 @@ def fake_message_event_guild(**field) -> "MessageCreateEvent":
168173
_fake = create_model("_fake", __base__=MessageCreateEvent)
169174

170175
class FakeEvent(_fake):
171-
id: str = "1234"
172176
channel_id: str = "abcd"
173177
guild_id: str = "efgh"
174178
content: str = "test"
@@ -178,7 +182,7 @@ class FakeEvent(_fake):
178182
class Config:
179183
extra = "forbid"
180184

181-
return FakeEvent(**field)
185+
return FakeEvent(id=str(get_msg_id() + 5555), **field)
182186

183187

184188
def fake_satori_bot_params(self_id: str = "test", platform: str = "test") -> dict:

tests/test_buttons.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ async def _(row: int):
110110

111111
adapter = get_adapter(Adapter)
112112
bot = ctx.create_bot(base=Bot, adapter=adapter, bot_info=None)
113-
event = fake_message_event_guild(message=Message("test 3"), id="123")
113+
event = fake_message_event_guild(message=Message("test 3"))
114114
ctx.receive_event(bot, event)
115115
ctx.should_call_send(
116116
event,

0 commit comments

Comments
 (0)