Skip to content

Commit 71676ba

Browse files
committed
fix: EntitlementIterator behaviour and type-hinting
1 parent 8a09b89 commit 71676ba

File tree

1 file changed

+52
-18
lines changed

1 file changed

+52
-18
lines changed

discord/iterators.py

Lines changed: 52 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
from .types.audit_log import AuditLog as AuditLogPayload
6565
from .types.guild import Guild as GuildPayload
6666
from .types.message import Message as MessagePayload
67+
from .types.monetization import Entitlement as EntitlementPayload
6768
from .types.threads import Thread as ThreadPayload
6869
from .types.user import PartialUser as PartialUserPayload
6970
from .user import User
@@ -988,11 +989,21 @@ def __init__(
988989
self.guild_id = guild_id
989990
self.exclude_ended = exclude_ended
990991

992+
self._filter = None
993+
994+
if self.before and self.after:
995+
self._retrieve_entitlements = self._retrieve_entitlements_before_strategy
996+
self._filter = lambda e: int(e["id"]) > self.after.id
997+
elif self.after:
998+
self._retrieve_entitlements = self._retrieve_entitlements_after_strategy
999+
else:
1000+
self._retrieve_entitlements = self._retrieve_entitlements_before_strategy
1001+
9911002
self.state = state
9921003
self.get_entitlements = state.http.list_entitlements
9931004
self.entitlements = asyncio.Queue()
9941005

995-
async def next(self) -> BanEntry:
1006+
async def next(self) -> Entitlement:
9961007
if self.entitlements.empty():
9971008
await self.fill_entitlements()
9981009

@@ -1014,30 +1025,53 @@ async def fill_entitlements(self):
10141025
if not self._get_retrieve():
10151026
return
10161027

1028+
data = await self._retrieve_entitlements(self.retrieve)
1029+
1030+
if self._filter:
1031+
data = list(filter(self._filter, data))
1032+
1033+
if len(data) < 100:
1034+
self.limit = 0 # terminate loop
1035+
1036+
for element in data:
1037+
await self.entitlements.put(Entitlement(data=element, state=self.state))
1038+
1039+
async def _retrieve_entitlements(self, retrieve) -> list[Entitlement]:
1040+
"""Retrieve entitlements and update next parameters."""
1041+
raise NotImplementedError
1042+
1043+
async def _retrieve_entitlements_before_strategy(self, retrieve: int) -> list[EntitlementPayload]:
1044+
"""Retrieve entitlements using before parameter."""
10171045
before = self.before.id if self.before else None
1018-
after = self.after.id if self.after else None
10191046
data = await self.get_entitlements(
10201047
self.state.application_id,
10211048
before=before,
1022-
after=after,
1023-
limit=self.retrieve,
1049+
limit=retrieve,
10241050
user_id=self.user_id,
10251051
guild_id=self.guild_id,
10261052
sku_ids=self.sku_ids,
10271053
exclude_ended=self.exclude_ended,
10281054
)
1055+
if len(data):
1056+
if self.limit is not None:
1057+
self.limit -= retrieve
1058+
self.before = Object(id=int(data[-1]["id"]))
1059+
return data
10291060

1030-
if not data:
1031-
# no data, terminate
1032-
return
1033-
1034-
if self.limit:
1035-
self.limit -= self.retrieve
1036-
1037-
if len(data) < 100:
1038-
self.limit = 0 # terminate loop
1039-
1040-
self.after = Object(id=int(data[-1]["id"]))
1041-
1042-
for element in reversed(data):
1043-
await self.entitlements.put(Entitlement(data=element, state=self.state))
1061+
async def _retrieve_entitlements_after_strategy(self, retrieve: int) -> list[EntitlementPayload]:
1062+
"""Retrieve entitlements using after parameter."""
1063+
after = self.after.id if self.after else None
1064+
data = await self.get_entitlements(
1065+
self.state.application_id,
1066+
after=after,
1067+
limit=retrieve,
1068+
user_id=self.user_id,
1069+
guild_id=self.guild_id,
1070+
sku_ids=self.sku_ids,
1071+
exclude_ended=self.exclude_ended,
1072+
)
1073+
if len(data):
1074+
if self.limit is not None:
1075+
self.limit -= retrieve
1076+
self.after = Object(id=int(data[-1]["id"]))
1077+
return data

0 commit comments

Comments
 (0)