|
33 | 33 | AsyncIterator,
|
34 | 34 | Awaitable,
|
35 | 35 | Callable,
|
| 36 | + Generator, |
36 | 37 | List,
|
37 | 38 | TypeVar,
|
38 | 39 | Union,
|
|
42 | 43 | from .errors import NoMoreItems
|
43 | 44 | from .object import Object
|
44 | 45 | from .utils import generate_snowflake, snowflake_time
|
45 |
| -from .utils.private import maybe_awaitable |
| 46 | +from .utils.private import maybe_awaitable, warn_deprecated |
46 | 47 |
|
47 | 48 | __all__ = (
|
48 | 49 | "ReactionIterator",
|
|
57 | 58 |
|
58 | 59 | if TYPE_CHECKING:
|
59 | 60 | from .abc import Snowflake
|
| 61 | + from .channel import MessageableChannel |
60 | 62 | from .guild import BanEntry, Guild
|
61 | 63 | from .member import Member
|
62 |
| - from .message import Message |
| 64 | + from .message import Message, MessagePin |
63 | 65 | from .monetization import Entitlement, Subscription
|
64 | 66 | from .scheduled_events import ScheduledEvent
|
65 | 67 | from .threads import Thread
|
66 | 68 | from .types.audit_log import AuditLog as AuditLogPayload
|
67 | 69 | from .types.guild import Guild as GuildPayload
|
68 | 70 | from .types.message import Message as MessagePayload
|
| 71 | + from .types.message import MessagePin as MessagePinPayload |
69 | 72 | from .types.monetization import Entitlement as EntitlementPayload
|
70 | 73 | from .types.monetization import Subscription as SubscriptionPayload
|
71 | 74 | from .types.threads import Thread as ThreadPayload
|
@@ -1157,3 +1160,85 @@ async def _retrieve_subscriptions_after_strategy(self, retrieve):
|
1157 | 1160 | self.limit -= retrieve
|
1158 | 1161 | self.after = Object(id=int(data[0]["id"]))
|
1159 | 1162 | return data
|
| 1163 | + |
| 1164 | + |
| 1165 | +class MessagePinIterator(_AsyncIterator["MessagePin"]): |
| 1166 | + def __init__( |
| 1167 | + self, |
| 1168 | + channel: MessageableChannel, |
| 1169 | + limit: int | None, |
| 1170 | + before: Snowflake | datetime.datetime | None = None, |
| 1171 | + ): |
| 1172 | + self._channel = channel |
| 1173 | + self.limit = limit |
| 1174 | + self.http = channel._state.http |
| 1175 | + |
| 1176 | + self.before: str | None |
| 1177 | + if before is None: |
| 1178 | + self.before = None |
| 1179 | + elif isinstance(before, datetime.datetime): |
| 1180 | + self.before = before.isoformat() |
| 1181 | + else: |
| 1182 | + self.before = snowflake_time(before.id).isoformat() |
| 1183 | + |
| 1184 | + self.update_before: Callable[[MessagePinPayload], str] = self.get_last_pinned |
| 1185 | + |
| 1186 | + self.endpoint = self.http.pins_from |
| 1187 | + |
| 1188 | + self.queue: asyncio.Queue[MessagePin] = asyncio.Queue() |
| 1189 | + self.has_more: bool = True |
| 1190 | + |
| 1191 | + async def next(self) -> MessagePin: |
| 1192 | + if self.queue.empty(): |
| 1193 | + await self.fill_queue() |
| 1194 | + |
| 1195 | + try: |
| 1196 | + return self.queue.get_nowait() |
| 1197 | + except asyncio.QueueEmpty: |
| 1198 | + raise NoMoreItems() |
| 1199 | + |
| 1200 | + @staticmethod |
| 1201 | + def get_last_pinned(data: MessagePinPayload) -> str: |
| 1202 | + return data["pinned_at"] |
| 1203 | + |
| 1204 | + async def fill_queue(self) -> None: |
| 1205 | + if not self.has_more: |
| 1206 | + raise NoMoreItems() |
| 1207 | + |
| 1208 | + if not hasattr(self, "channel"): |
| 1209 | + channel = await self._channel._get_channel() |
| 1210 | + self.channel = channel |
| 1211 | + |
| 1212 | + limit = 50 if self.limit is None else min(self.limit, 50) |
| 1213 | + data = await self.endpoint(self.channel.id, before=self.before, limit=limit) |
| 1214 | + |
| 1215 | + pins: list[MessagePinPayload] = data.get("items", []) |
| 1216 | + for d in pins: |
| 1217 | + self.queue.put_nowait(self.create_pin(d)) |
| 1218 | + |
| 1219 | + self.has_more = data.get("has_more", False) |
| 1220 | + if self.limit is not None: |
| 1221 | + self.limit -= len(pins) |
| 1222 | + if self.limit <= 0: |
| 1223 | + self.has_more = False |
| 1224 | + |
| 1225 | + if self.has_more: |
| 1226 | + self.before = self.update_before(pins[-1]) |
| 1227 | + |
| 1228 | + def create_pin(self, data: MessagePinPayload) -> MessagePin: |
| 1229 | + from .message import MessagePin |
| 1230 | + |
| 1231 | + return MessagePin(state=self.channel._state, channel=self.channel, data=data) |
| 1232 | + |
| 1233 | + async def retrieve_inner(self) -> list[Message]: |
| 1234 | + pins = await self.flatten() |
| 1235 | + return [p.message for p in pins] |
| 1236 | + |
| 1237 | + def __await__(self) -> Generator[Any, Any, MessagePin]: |
| 1238 | + warn_deprecated( |
| 1239 | + f"Messageable.pins() returning a list of Message", |
| 1240 | + since="2.7", |
| 1241 | + removed="3.0", |
| 1242 | + reference="The documentation of pins()", |
| 1243 | + ) |
| 1244 | + return self.retrieve_inner().__await__() |
0 commit comments