|
9 | 9 | from typing import TYPE_CHECKING, Any, Callable, Literal, NoReturn, Protocol, TypeVar, Union |
10 | 10 | from typing_extensions import Self, SupportsIndex, deprecated |
11 | 11 |
|
| 12 | +from nonebot import get_driver |
12 | 13 | from nonebot.exception import FinishedException |
13 | 14 | from nonebot.internal.adapter import Bot, Event, Message |
| 15 | +from nonebot.internal.driver import HTTPClientMixin, Request |
14 | 16 | from nonebot.internal.matcher import current_bot, current_event |
15 | 17 | from tarina import lang |
16 | 18 | from tarina.context import ContextModel |
|
33 | 35 | I18n, |
34 | 36 | Image, |
35 | 37 | Keyboard, |
| 38 | + Media, |
36 | 39 | Reference, |
37 | 40 | RefNode, |
38 | 41 | Reply, |
@@ -1178,3 +1181,29 @@ def load(cls: type[UniMessage[Segment]], data: str | list[dict[str, Any]]): |
1178 | 1181 | else: |
1179 | 1182 | _data = data |
1180 | 1183 | return cls(get_segment_class(seg_data["type"]).load(seg_data) for seg_data in _data) |
| 1184 | + |
| 1185 | + async def download(self, stream: bool = False, **kwargs): |
| 1186 | + """将消息中的媒体链接下载为文件数据 |
| 1187 | +
|
| 1188 | + Args: |
| 1189 | + stream (bool, optional): 是否以流式下载. Defaults to False. |
| 1190 | + **kwargs: 传递给下载器的参数 |
| 1191 | + """ |
| 1192 | + driver = get_driver() |
| 1193 | + for media in self.select(Media): |
| 1194 | + if not media.url: |
| 1195 | + continue |
| 1196 | + if not isinstance(driver, HTTPClientMixin): |
| 1197 | + raise TypeError("Current driver does not support http client") |
| 1198 | + request = Request("GET", media.url) |
| 1199 | + sess = driver.get_session(**kwargs) |
| 1200 | + raw = b"" |
| 1201 | + if stream: |
| 1202 | + async for chunk in sess.stream_request(request): |
| 1203 | + raw += chunk.content # type: ignore |
| 1204 | + else: |
| 1205 | + response = await sess.request(request) |
| 1206 | + raw = response.content # type: ignore |
| 1207 | + media.url = None |
| 1208 | + media.raw = raw |
| 1209 | + return self |
0 commit comments