|
33 | 33 | AsyncIterator,
|
34 | 34 | Awaitable,
|
35 | 35 | Callable,
|
| 36 | + Generator, |
36 | 37 | List,
|
37 | 38 | TypeVar,
|
38 | 39 | Union,
|
|
41 | 42 | from .audit_logs import AuditLogEntry
|
42 | 43 | from .errors import NoMoreItems
|
43 | 44 | from .object import Object
|
44 |
| -from .utils import maybe_coroutine, snowflake_time, time_snowflake |
| 45 | +from .utils import maybe_coroutine, snowflake_time, time_snowflake, warn_deprecated |
45 | 46 |
|
46 | 47 | __all__ = (
|
47 | 48 | "ReactionIterator",
|
|
56 | 57 |
|
57 | 58 | if TYPE_CHECKING:
|
58 | 59 | from .abc import Snowflake
|
| 60 | + from .channel import MessageableChannel |
59 | 61 | from .guild import BanEntry, Guild
|
60 | 62 | from .member import Member
|
61 |
| - from .message import Message |
| 63 | + from .message import Message, MessagePin |
62 | 64 | from .monetization import Entitlement, Subscription
|
63 | 65 | from .scheduled_events import ScheduledEvent
|
64 | 66 | from .threads import Thread
|
65 | 67 | from .types.audit_log import AuditLog as AuditLogPayload
|
66 | 68 | from .types.guild import Guild as GuildPayload
|
67 | 69 | from .types.message import Message as MessagePayload
|
| 70 | + from .types.message import MessagePin as MessagePinPayload |
68 | 71 | from .types.monetization import Entitlement as EntitlementPayload
|
69 | 72 | from .types.monetization import Subscription as SubscriptionPayload
|
70 | 73 | from .types.threads import Thread as ThreadPayload
|
@@ -1198,3 +1201,85 @@ async def _retrieve_subscriptions_after_strategy(self, retrieve):
|
1198 | 1201 | self.limit -= retrieve
|
1199 | 1202 | self.after = Object(id=int(data[0]["id"]))
|
1200 | 1203 | return data
|
| 1204 | + |
| 1205 | + |
| 1206 | +class MessagePinIterator(_AsyncIterator["MessagePin"]): |
| 1207 | + def __init__( |
| 1208 | + self, |
| 1209 | + channel: MessageableChannel, |
| 1210 | + limit: int | None, |
| 1211 | + before: Snowflake | datetime.datetime | None = None, |
| 1212 | + ): |
| 1213 | + self._channel = channel |
| 1214 | + self.limit = limit |
| 1215 | + self.http = channel._state.http |
| 1216 | + |
| 1217 | + self.before: str | None |
| 1218 | + if before is None: |
| 1219 | + self.before = None |
| 1220 | + elif isinstance(before, datetime.datetime): |
| 1221 | + self.before = before.isoformat() |
| 1222 | + else: |
| 1223 | + self.before = snowflake_time(before.id).isoformat() |
| 1224 | + |
| 1225 | + self.update_before: Callable[[MessagePinPayload], str] = self.get_last_pinned |
| 1226 | + |
| 1227 | + self.endpoint = self.http.pins_from |
| 1228 | + |
| 1229 | + self.queue: asyncio.Queue[MessagePin] = asyncio.Queue() |
| 1230 | + self.has_more: bool = True |
| 1231 | + |
| 1232 | + async def next(self) -> MessagePin: |
| 1233 | + if self.queue.empty(): |
| 1234 | + await self.fill_queue() |
| 1235 | + |
| 1236 | + try: |
| 1237 | + return self.queue.get_nowait() |
| 1238 | + except asyncio.QueueEmpty: |
| 1239 | + raise NoMoreItems() |
| 1240 | + |
| 1241 | + @staticmethod |
| 1242 | + def get_last_pinned(data: MessagePinPayload) -> str: |
| 1243 | + return data["pinned_at"] |
| 1244 | + |
| 1245 | + async def fill_queue(self) -> None: |
| 1246 | + if not self.has_more: |
| 1247 | + raise NoMoreItems() |
| 1248 | + |
| 1249 | + if not hasattr(self, "channel"): |
| 1250 | + channel = await self._channel._get_channel() |
| 1251 | + self.channel = channel |
| 1252 | + |
| 1253 | + limit = 50 if self.limit is None else min(self.limit, 50) |
| 1254 | + data = await self.endpoint(self.channel.id, before=self.before, limit=limit) |
| 1255 | + |
| 1256 | + pins: list[MessagePinPayload] = data.get("items", []) |
| 1257 | + for d in pins: |
| 1258 | + self.queue.put_nowait(self.create_pin(d)) |
| 1259 | + |
| 1260 | + self.has_more = data.get("has_more", False) |
| 1261 | + if self.limit is not None: |
| 1262 | + self.limit -= len(pins) |
| 1263 | + if self.limit <= 0: |
| 1264 | + self.has_more = False |
| 1265 | + |
| 1266 | + if self.has_more: |
| 1267 | + self.before = self.update_before(pins[-1]) |
| 1268 | + |
| 1269 | + def create_pin(self, data: MessagePinPayload) -> MessagePin: |
| 1270 | + from .message import MessagePin |
| 1271 | + |
| 1272 | + return MessagePin(state=self.channel._state, channel=self.channel, data=data) |
| 1273 | + |
| 1274 | + async def retrieve_inner(self) -> list[Message]: |
| 1275 | + pins = await self.flatten() |
| 1276 | + return [p.message for p in pins] |
| 1277 | + |
| 1278 | + def __await__(self) -> Generator[Any, Any, MessagePin]: |
| 1279 | + warn_deprecated( |
| 1280 | + f"Messageable.pins() returning a list of Message", |
| 1281 | + since="2.7", |
| 1282 | + removed="3.0", |
| 1283 | + reference="The documentation of pins()", |
| 1284 | + ) |
| 1285 | + return self.retrieve_inner().__await__() |
0 commit comments