|
33 | 33 | AsyncIterator, |
34 | 34 | Awaitable, |
35 | 35 | Callable, |
| 36 | + Generator, |
36 | 37 | List, |
37 | 38 | TypeVar, |
38 | 39 | Union, |
|
41 | 42 | from .audit_logs import AuditLogEntry |
42 | 43 | from .errors import NoMoreItems |
43 | 44 | from .object import Object |
44 | | -from .utils import maybe_coroutine, snowflake_time, time_snowflake |
| 45 | +from .utils import maybe_coroutine, snowflake_time, time_snowflake, warn_deprecated |
45 | 46 |
|
46 | 47 | __all__ = ( |
47 | 48 | "ReactionIterator", |
|
56 | 57 |
|
57 | 58 | if TYPE_CHECKING: |
58 | 59 | from .abc import Snowflake |
| 60 | + from .channel import MessageableChannel |
59 | 61 | from .guild import BanEntry, Guild |
60 | 62 | from .member import Member |
61 | | - from .message import Message |
| 63 | + from .message import Message, MessagePin |
62 | 64 | from .monetization import Entitlement, Subscription |
63 | 65 | from .scheduled_events import ScheduledEvent |
64 | 66 | from .threads import Thread |
65 | 67 | from .types.audit_log import AuditLog as AuditLogPayload |
66 | 68 | from .types.guild import Guild as GuildPayload |
67 | 69 | from .types.message import Message as MessagePayload |
| 70 | + from .types.message import MessagePin as MessagePinPayload |
68 | 71 | from .types.monetization import Entitlement as EntitlementPayload |
69 | 72 | from .types.monetization import Subscription as SubscriptionPayload |
70 | 73 | from .types.threads import Thread as ThreadPayload |
@@ -1198,3 +1201,85 @@ async def _retrieve_subscriptions_after_strategy(self, retrieve): |
1198 | 1201 | self.limit -= retrieve |
1199 | 1202 | self.after = Object(id=int(data[0]["id"])) |
1200 | 1203 | return data |
| 1204 | + |
| 1205 | + |
| 1206 | +class MessagePinIterator(_AsyncIterator["MessagePin"]): |
| 1207 | + def __init__( |
| 1208 | + self, |
| 1209 | + channel: MessageableChannel, |
| 1210 | + limit: int | None, |
| 1211 | + before: Snowflake | datetime.datetime | None = None, |
| 1212 | + ): |
| 1213 | + self._channel = channel |
| 1214 | + self.limit = limit |
| 1215 | + self.http = channel._state.http |
| 1216 | + |
| 1217 | + self.before: str | None |
| 1218 | + if before is None: |
| 1219 | + self.before = None |
| 1220 | + elif isinstance(before, datetime.datetime): |
| 1221 | + self.before = before.isoformat() |
| 1222 | + else: |
| 1223 | + self.before = snowflake_time(before.id).isoformat() |
| 1224 | + |
| 1225 | + self.update_before: Callable[[MessagePinPayload], str] = self.get_last_pinned |
| 1226 | + |
| 1227 | + self.endpoint = self.http.pins_from |
| 1228 | + |
| 1229 | + self.queue: asyncio.Queue[MessagePin] = asyncio.Queue() |
| 1230 | + self.has_more: bool = True |
| 1231 | + |
| 1232 | + async def next(self) -> MessagePin: |
| 1233 | + if self.queue.empty(): |
| 1234 | + await self.fill_queue() |
| 1235 | + |
| 1236 | + try: |
| 1237 | + return self.queue.get_nowait() |
| 1238 | + except asyncio.QueueEmpty: |
| 1239 | + raise NoMoreItems() |
| 1240 | + |
| 1241 | + @staticmethod |
| 1242 | + def get_last_pinned(data: MessagePinPayload) -> str: |
| 1243 | + return data["pinned_at"] |
| 1244 | + |
| 1245 | + async def fill_queue(self) -> None: |
| 1246 | + if not self.has_more: |
| 1247 | + raise NoMoreItems() |
| 1248 | + |
| 1249 | + if not hasattr(self, "channel"): |
| 1250 | + channel = await self._channel._get_channel() |
| 1251 | + self.channel = channel |
| 1252 | + |
| 1253 | + limit = 50 if self.limit is None else min(self.limit, 50) |
| 1254 | + data = await self.endpoint(self.channel.id, before=self.before, limit=limit) |
| 1255 | + |
| 1256 | + pins: list[MessagePinPayload] = data.get("items", []) |
| 1257 | + for d in pins: |
| 1258 | + self.queue.put_nowait(self.create_pin(d)) |
| 1259 | + |
| 1260 | + self.has_more = data.get("has_more", False) |
| 1261 | + if self.limit is not None: |
| 1262 | + self.limit -= len(pins) |
| 1263 | + if self.limit <= 0: |
| 1264 | + self.has_more = False |
| 1265 | + |
| 1266 | + if self.has_more: |
| 1267 | + self.before = self.update_before(pins[-1]) |
| 1268 | + |
| 1269 | + def create_pin(self, data: MessagePinPayload) -> MessagePin: |
| 1270 | + from .message import MessagePin |
| 1271 | + |
| 1272 | + return MessagePin(state=self.channel._state, channel=self.channel, data=data) |
| 1273 | + |
| 1274 | + async def retrieve_inner(self) -> list[Message]: |
| 1275 | + pins = await self.flatten() |
| 1276 | + return [p.message for p in pins] |
| 1277 | + |
| 1278 | + def __await__(self) -> Generator[Any, Any, MessagePin]: |
| 1279 | + warn_deprecated( |
| 1280 | + f"Messageable.pins() returning a list of Message", |
| 1281 | + since="2.7", |
| 1282 | + removed="3.0", |
| 1283 | + reference="The documentation of pins()", |
| 1284 | + ) |
| 1285 | + return self.retrieve_inner().__await__() |
0 commit comments