Skip to content

Commit 978e0c8

Browse files
committed
feat: subscriptions & related changes
1 parent b59ab2c commit 978e0c8

File tree

9 files changed

+382
-52
lines changed

9 files changed

+382
-52
lines changed

discord/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2048,7 +2048,7 @@ async def fetch_skus(self) -> list[SKU]:
20482048
The bot's SKUs.
20492049
"""
20502050
data = await self._connection.http.list_skus(self.application_id)
2051-
return [SKU(data=s) for s in data]
2051+
return [SKU(state=self._connection, data=s) for s in data]
20522052

20532053
def entitlements(
20542054
self,

discord/enums.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1053,6 +1053,14 @@ class PollLayoutType(Enum):
10531053
default = 1
10541054

10551055

1056+
class SubscriptionStatus(Enum):
1057+
"""The status of a subscription."""
1058+
1059+
active = 1
1060+
ending = 2
1061+
inactive = 3
1062+
1063+
10561064
T = TypeVar("T")
10571065

10581066

discord/http.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3027,6 +3027,41 @@ def delete_test_entitlement(
30273027
)
30283028
return self.request(r)
30293029

3030+
async def list_sku_subscriptions(
3031+
self,
3032+
sku_id: Snowflake,
3033+
*,
3034+
before: Snowflake | None = None,
3035+
after: Snowflake | None = None,
3036+
limit: int = 50,
3037+
user_id: Snowflake | None = None,
3038+
) -> Response[list[monetization.Subscription]]:
3039+
params: dict[str, Any] = {}
3040+
if before is not None:
3041+
params["before"] = before
3042+
if after is not None:
3043+
params["after"] = after
3044+
if limit is not None:
3045+
params["limit"] = limit
3046+
if user_id is not None:
3047+
params["user_id"] = user_id
3048+
return self.request(
3049+
Route("GET", "/skus/{sku_id}/subscriptions", sku_id=sku_id),
3050+
params=params,
3051+
)
3052+
3053+
async def get_subscription(
3054+
self,
3055+
sku_id: Snowflake,
3056+
subscription_id: Snowflake,
3057+
) -> Response[monetization.Subscription]:
3058+
return self.request(
3059+
Route(
3060+
"GET", "/skus/{sku_id}/subscriptions/{subscription_id}",
3061+
sku_id=sku_id, subscription_id=subscription_id
3062+
)
3063+
)
3064+
30303065
# Onboarding
30313066

30323067
def get_onboarding(self, guild_id: Snowflake) -> Response[onboarding.Onboarding]:

discord/iterators.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
"MemberIterator",
5353
"ScheduledEventSubscribersIterator",
5454
"EntitlementIterator",
55+
"SubscriptionIterator",
5556
)
5657

5758
if TYPE_CHECKING:
@@ -67,6 +68,7 @@
6768
from .types.threads import Thread as ThreadPayload
6869
from .types.user import PartialUser as PartialUserPayload
6970
from .user import User
71+
from .types.monetization import Subscription as SubscriptionPayload
7072

7173
T = TypeVar("T")
7274
OT = TypeVar("OT")
@@ -1041,3 +1043,105 @@ async def fill_entitlements(self):
10411043

10421044
for element in reversed(data):
10431045
await self.entitlements.put(Entitlement(data=element, state=self.state))
1046+
1047+
1048+
class SubscriptionIterator(_AsyncIterator["Subscription"]):
1049+
def __init__(
1050+
self,
1051+
state,
1052+
sku_id: int,
1053+
limit: int = None,
1054+
before: datetime.datetime | None = None,
1055+
after: datetime.datetime | None = None,
1056+
user_id: int | None = None
1057+
):
1058+
if isinstance(before, datetime.datetime):
1059+
before = Object(id=time_snowflake(before, high=False))
1060+
if isinstance(after, datetime.datetime):
1061+
after = Object(id=time_snowflake(after, high=True))
1062+
1063+
self.state = state
1064+
self.sku_id = sku_id
1065+
self.limit = limit
1066+
self.before = before
1067+
self.after = after
1068+
self.user_id = user_id
1069+
1070+
self._filter = None
1071+
1072+
self.get_subscriptions = state.http.list_sku_subscriptions
1073+
self.subscriptions = asyncio.Queue()
1074+
1075+
if self.before and self.after:
1076+
self._retrieve_subscriptions = self._retrieve_subscriptions_before_strategy # type: ignore
1077+
self._filter = lambda m: int(m["id"]) > self.after.id
1078+
elif self.after:
1079+
self._retrieve_subscriptions = self._retrieve_subscriptions_after_strategy # type: ignore
1080+
else:
1081+
self._retrieve_subscriptions = self._retrieve_subscriptions_before_strategy # type: ignore
1082+
1083+
async def next(self) -> Guild:
1084+
if self.subscriptions.empty():
1085+
await self.fill_subscriptions()
1086+
1087+
try:
1088+
return self.subscriptions.get_nowait()
1089+
except asyncio.QueueEmpty:
1090+
raise NoMoreItems()
1091+
1092+
def _get_retrieve(self):
1093+
l = self.limit
1094+
if l is None or l > 100:
1095+
r = 100
1096+
else:
1097+
r = l
1098+
self.retrieve = r
1099+
return r > 0
1100+
1101+
def create_subscription(self, data):
1102+
from .monetization import Subscription
1103+
1104+
return Subscription(state=self.state, data=data)
1105+
1106+
async def fill_subscriptions(self):
1107+
if self._get_retrieve():
1108+
data = await self._retrieve_subscriptions(self.retrieve)
1109+
if self.limit is None or len(data) < 100:
1110+
self.limit = 0
1111+
1112+
if self._filter:
1113+
data = filter(self._filter, data)
1114+
1115+
for element in data:
1116+
await self.subscriptions.put(self.create_subscription(element))
1117+
1118+
async def _retrieve_subscriptions(self, retrieve) -> list[SubscriptionPayload]:
1119+
raise NotImplementedError
1120+
1121+
async def _retrieve_subscriptions_before_strategy(self, retrieve):
1122+
before = self.before.id if self.before else None
1123+
data: list[SubscriptionPayload] = await self.get_subscriptions(
1124+
self.sku_id,
1125+
limit=retrieve,
1126+
before=before,
1127+
user_id=self.user_id,
1128+
)
1129+
if len(data):
1130+
if self.limit is not None:
1131+
self.limit -= retrieve
1132+
self.before = Object(id=int(data[-1]["id"]))
1133+
return data
1134+
1135+
async def _retrieve_subscriptions_after_strategy(self, retrieve):
1136+
after = self.after.id if self.after else None
1137+
data: list[SubscriptionPayload] = await self.get_subscriptions(
1138+
self.sku_id,
1139+
limit=retrieve,
1140+
after=after,
1141+
user_id=self.user_id,
1142+
)
1143+
if len(data):
1144+
if self.limit is not None:
1145+
self.limit -= retrieve
1146+
self.after = Object(id=int(data[0]["id"]))
1147+
return data

discord/monetization.py

Lines changed: 113 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,20 @@
2727

2828
from typing import TYPE_CHECKING
2929

30-
from .enums import EntitlementType, SKUType, try_enum
30+
from .abc import Snowflake, SnowflakeTime
31+
from .enums import EntitlementType, SKUType, SubscriptionStatus, try_enum
3132
from .flags import SKUFlags
33+
from .iterators import SubscriptionIterator
3234
from .mixins import Hashable
33-
from .utils import MISSING, _get_as_snowflake, parse_time
35+
from .utils import cached_property, cached_slot_property, MISSING, _get_as_snowflake, parse_time
3436

3537
if TYPE_CHECKING:
3638
from datetime import datetime
3739

3840
from .state import ConnectionState
3941
from .types.monetization import SKU as SKUPayload
4042
from .types.monetization import Entitlement as EntitlementPayload
43+
from .types.monetization import Subscription as SubscriptionPayload
4144

4245

4346
__all__ = (
@@ -68,6 +71,7 @@ class SKU(Hashable):
6871
"""
6972

7073
__slots__ = (
74+
"_state",
7175
"id",
7276
"type",
7377
"application_id",
@@ -76,7 +80,8 @@ class SKU(Hashable):
7680
"flags",
7781
)
7882

79-
def __init__(self, *, data: SKUPayload) -> None:
83+
def __init__(self, *, state: ConnectionState, data: SKUPayload) -> None:
84+
self._state: ConnectionState = state
8085
self.id: int = int(data["id"])
8186
self.type: SKUType = try_enum(SKUType, data["type"])
8287
self.application_id: int = int(data["application_id"])
@@ -101,12 +106,48 @@ def url(self) -> str:
101106
""":class:`str`: Returns the URL for the SKU."""
102107
return f"https://discord.com/application-directory/{self.application_id}/store/{self.id}"
103108

109+
def list_subscriptions(
110+
self,
111+
user: Snowflake, # user is required because this is a bot, we are not using oauth2
112+
*,
113+
before: SnowflakeTime | None = None,
114+
after: SnowflakeTime | None = None,
115+
limit: int | None = 100,
116+
) -> SubscriptionIterator:
117+
"""Returns an :class:`.AsyncIterator` that enables fetching the SKU's subscriptions.
118+
119+
.. versionadded:: 2.7
120+
121+
Parameters
122+
----------
123+
user: :class:`.abc.Snowflake`
124+
The user to retrieve subscriptions for.
125+
before: :class:`.abc.Snowflake` | :class:`datetime.datetime` | None
126+
Retrieves subscriptions before this date or object.
127+
If a datetime is provided, it is recommended to use a UTC-aware datetime.
128+
If the datetime is naive, it is assumed to be local time.
129+
after: :class:`.abc.Snowflake` | :class:`datetime.datetime` | None
130+
Retrieve subscriptions after this date or object.
131+
If a datetime is provided, it is recommended to use a UTC-aware datetime.
132+
If the datetime is naive, it is assumed to be local time.
133+
limit: :class:`int` | None
134+
The number of subscriptions to retrieve. If ``None``, retrieves all subscriptions.
135+
"""
136+
return SubscriptionIterator(self._state, self.id, user_id=user.id, before=before, after=after, limit=limit)
137+
104138

105139
class Entitlement(Hashable):
106140
"""Represents a Discord entitlement.
107141
108142
.. versionadded:: 2.5
109143
144+
.. notice::
145+
146+
As of October 1, 2024, entitlements that have been purchased will have ``ends_at`` set to ``None``
147+
unless the parent :class:`Subscription` has been cancelled.
148+
149+
`See the Discord changelog. <https://discord.com/developers/docs/change-log#premium-apps-entitlement-migration-and-new-subscription-api>`_
150+
110151
Attributes
111152
----------
112153
id: :class:`int`
@@ -130,7 +171,7 @@ class Entitlement(Hashable):
130171
consumed: :class:`bool`
131172
Whether or not this entitlement has been consumed.
132173
This will always be ``False`` for entitlements that are not
133-
of type :attr:`EntitlementType.consumable`.
174+
from an SKU of type :attr:`SKUType.consumable`.
134175
"""
135176

136177
__slots__ = (
@@ -158,7 +199,7 @@ def __init__(self, *, data: EntitlementPayload, state: ConnectionState) -> None:
158199
self.starts_at: datetime | MISSING = (
159200
parse_time(data.get("starts_at")) or MISSING
160201
)
161-
self.ends_at: datetime | MISSING = parse_time(data.get("ends_at")) or MISSING
202+
self.ends_at: datetime | MISSING | None = parse_time(ea) if (ea := data.get("ends_at")) is not None else MISSING
162203
self.guild_id: int | MISSING = _get_as_snowflake(data, "guild_id") or MISSING
163204
self.consumed: bool = data.get("consumed", False)
164205

@@ -177,18 +218,13 @@ async def consume(self) -> None:
177218
178219
Consumes this entitlement.
179220
180-
This can only be done on entitlements of type :attr:`EntitlementType.consumable`.
221+
This can only be done on entitlements from an SKU of type :attr:`SKUType.consumable`.
181222
182223
Raises
183224
------
184-
TypeError
185-
The entitlement is not consumable.
186225
HTTPException
187226
Consuming the entitlement failed.
188227
"""
189-
if self.type is not EntitlementType.consumable:
190-
raise TypeError("Cannot consume non-consumable entitlement")
191-
192228
await self._state.http.consume_entitlement(self._state.application_id, self.id)
193229
self.consumed = True
194230

@@ -205,3 +241,69 @@ async def delete(self) -> None:
205241
Deleting the entitlement failed.
206242
"""
207243
await self._state.http.delete_test_entitlement(self.application_id, self.id)
244+
245+
246+
class Subscription(Hashable):
247+
"""Represents a user making recurring payments for one or more SKUs.
248+
249+
Successful payments grant the user access to entitlements associated with the SKU.
250+
251+
.. versionadded:: 2.7
252+
253+
Attributes
254+
----------
255+
id: :class:`int`
256+
The subscription's ID.
257+
user_id: :class:`int`
258+
The ID of the user that owns this subscription.
259+
sku_ids: List[:class:`int`]
260+
The IDs of the SKUs this subscription is for.
261+
entitlement_ids: List[:class:`int`]
262+
The IDs of the entitlements this subscription is for.
263+
current_period_start: :class:`datetime.datetime`
264+
The start of the current subscription period.
265+
current_period_end: :class:`datetime.datetime`
266+
The end of the current subscription period.
267+
status: :class:`SubscriptionStatus`
268+
The status of the subscription.
269+
canceled_at: :class:`datetime.datetime` | ``None``
270+
When the subscription was canceled.
271+
"""
272+
__slots__ = (
273+
"_state",
274+
"id",
275+
"user_id",
276+
"sku_ids",
277+
"entitlement_ids",
278+
"current_period_start",
279+
"current_period_end",
280+
"status",
281+
"canceled_at",
282+
"country",
283+
)
284+
285+
def __init__(self, *, state: ConnectionState, data: SubscriptionPayload) -> None:
286+
self._state: ConnectionState = state
287+
self.id: int = int(data["id"])
288+
self.user_id: int = int(data["user_id"])
289+
self.sku_ids: list[int] = list(map(int, data["sku_ids"]))
290+
self.entitlement_ids: list[int] = list(map(int, data["entitlement_ids"]))
291+
self.current_period_start: datetime = parse_time(data["current_period_start"])
292+
self.current_period_end: datetime = parse_time(data["current_period_end"])
293+
self.status: SubscriptionStatus = try_enum(SubscriptionStatus, data["status"])
294+
self.canceled_at: datetime | None = parse_time(data.get("canceled_at"))
295+
self.country: str | None = data.get("country") # Not documented, it is only available with oauth2, not bots
296+
297+
def __repr__(self) -> str:
298+
return (
299+
f"<Subscription id={self.id} user_id={self.user_id} status={self.status}>"
300+
)
301+
302+
def __eq__(self, other: object) -> bool:
303+
return isinstance(other, self.__class__) and other.id == self.id
304+
305+
@property
306+
def user(self):
307+
"""Optional[:class:`User`]: The user that owns this subscription."""
308+
return self._state.get_user(self.user_id)
309+

0 commit comments

Comments
 (0)