64
64
from .types .audit_log import AuditLog as AuditLogPayload
65
65
from .types .guild import Guild as GuildPayload
66
66
from .types .message import Message as MessagePayload
67
+ from .types .monetization import Entitlement as EntitlementPayload
67
68
from .types .threads import Thread as ThreadPayload
68
69
from .types .user import PartialUser as PartialUserPayload
69
70
from .user import User
@@ -988,11 +989,21 @@ def __init__(
988
989
self .guild_id = guild_id
989
990
self .exclude_ended = exclude_ended
990
991
992
+ self ._filter = None
993
+
994
+ if self .before and self .after :
995
+ self ._retrieve_entitlements = self ._retrieve_entitlements_before_strategy
996
+ self ._filter = lambda e : int (e ["id" ]) > self .after .id
997
+ elif self .after :
998
+ self ._retrieve_entitlements = self ._retrieve_entitlements_after_strategy
999
+ else :
1000
+ self ._retrieve_entitlements = self ._retrieve_entitlements_before_strategy
1001
+
991
1002
self .state = state
992
1003
self .get_entitlements = state .http .list_entitlements
993
1004
self .entitlements = asyncio .Queue ()
994
1005
995
- async def next (self ) -> BanEntry :
1006
+ async def next (self ) -> Entitlement :
996
1007
if self .entitlements .empty ():
997
1008
await self .fill_entitlements ()
998
1009
@@ -1014,30 +1025,53 @@ async def fill_entitlements(self):
1014
1025
if not self ._get_retrieve ():
1015
1026
return
1016
1027
1028
+ data = await self ._retrieve_entitlements (self .retrieve )
1029
+
1030
+ if self ._filter :
1031
+ data = list (filter (self ._filter , data ))
1032
+
1033
+ if len (data ) < 100 :
1034
+ self .limit = 0 # terminate loop
1035
+
1036
+ for element in data :
1037
+ await self .entitlements .put (Entitlement (data = element , state = self .state ))
1038
+
1039
+ async def _retrieve_entitlements (self , retrieve ) -> list [Entitlement ]:
1040
+ """Retrieve entitlements and update next parameters."""
1041
+ raise NotImplementedError
1042
+
1043
+ async def _retrieve_entitlements_before_strategy (self , retrieve : int ) -> list [EntitlementPayload ]:
1044
+ """Retrieve entitlements using before parameter."""
1017
1045
before = self .before .id if self .before else None
1018
- after = self .after .id if self .after else None
1019
1046
data = await self .get_entitlements (
1020
1047
self .state .application_id ,
1021
1048
before = before ,
1022
- after = after ,
1023
- limit = self .retrieve ,
1049
+ limit = retrieve ,
1024
1050
user_id = self .user_id ,
1025
1051
guild_id = self .guild_id ,
1026
1052
sku_ids = self .sku_ids ,
1027
1053
exclude_ended = self .exclude_ended ,
1028
1054
)
1055
+ if len (data ):
1056
+ if self .limit is not None :
1057
+ self .limit -= retrieve
1058
+ self .before = Object (id = int (data [- 1 ]["id" ]))
1059
+ return data
1029
1060
1030
- if not data :
1031
- # no data, terminate
1032
- return
1033
-
1034
- if self .limit :
1035
- self .limit -= self .retrieve
1036
-
1037
- if len (data ) < 100 :
1038
- self .limit = 0 # terminate loop
1039
-
1040
- self .after = Object (id = int (data [- 1 ]["id" ]))
1041
-
1042
- for element in reversed (data ):
1043
- await self .entitlements .put (Entitlement (data = element , state = self .state ))
1061
+ async def _retrieve_entitlements_after_strategy (self , retrieve : int ) -> list [EntitlementPayload ]:
1062
+ """Retrieve entitlements using after parameter."""
1063
+ after = self .after .id if self .after else None
1064
+ data = await self .get_entitlements (
1065
+ self .state .application_id ,
1066
+ after = after ,
1067
+ limit = retrieve ,
1068
+ user_id = self .user_id ,
1069
+ guild_id = self .guild_id ,
1070
+ sku_ids = self .sku_ids ,
1071
+ exclude_ended = self .exclude_ended ,
1072
+ )
1073
+ if len (data ):
1074
+ if self .limit is not None :
1075
+ self .limit -= retrieve
1076
+ self .after = Object (id = int (data [- 1 ]["id" ]))
1077
+ return data
0 commit comments