Skip to content

Commit 9c4c3b6

Browse files
committed
feat: first batch of changes
definitely doesn't work yet, just felt like there was way too much changes not to commit already
1 parent de83d48 commit 9c4c3b6

35 files changed

+567
-195
lines changed

discord/abc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@
8989
from .member import Member
9090
from .message import Message, MessageReference, PartialMessage
9191
from .poll import Poll
92-
from .state import ConnectionState
92+
from .app.state import ConnectionState
9393
from .threads import Thread
9494
from .types.channel import Channel as ChannelPayload
9595
from .types.channel import GuildChannel as GuildChannelPayload
@@ -1669,7 +1669,7 @@ async def send(
16691669

16701670
ret = state.create_message(channel=channel, data=data)
16711671
if view:
1672-
state.store_view(view, ret.id)
1672+
await state.store_view(view, ret.id)
16731673
view.message = ret
16741674

16751675
if delete_after is not None:

discord/app/cache.py

Lines changed: 306 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,306 @@
1+
"""
2+
The MIT License (MIT)
3+
4+
Copyright (c) 2021-present Pycord Development
5+
6+
Permission is hereby granted, free of charge, to any person obtaining a
7+
copy of this software and associated documentation files (the "Software"),
8+
to deal in the Software without restriction, including without limitation
9+
the rights to use, copy, modify, merge, publish, distribute, sublicense,
10+
and/or sell copies of the Software, and to permit persons to whom the
11+
Software is furnished to do so, subject to the following conditions:
12+
13+
The above copyright notice and this permission notice shall be included in
14+
all copies or substantial portions of the Software.
15+
16+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
17+
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
21+
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
22+
DEALINGS IN THE SOFTWARE.
23+
"""
24+
25+
from collections import OrderedDict, deque
26+
from typing import Deque, Protocol
27+
28+
from discord.app.state import ConnectionState
29+
from discord.message import Message
30+
31+
from ..abc import PrivateChannel
32+
from ..channel import DMChannel
33+
from ..emoji import AppEmoji, GuildEmoji
34+
from ..guild import Guild
35+
from ..poll import Poll
36+
from ..sticker import GuildSticker, Sticker
37+
from ..ui.modal import Modal, ModalStore
38+
from ..ui.view import View, ViewStore
39+
from ..user import User
40+
from ..types.user import User as UserPayload
41+
from ..types.emoji import Emoji as EmojiPayload
42+
from ..types.sticker import GuildSticker as GuildStickerPayload
43+
from ..types.channel import DMChannel as DMChannelPayload
44+
45+
class Cache(Protocol):
46+
# users
47+
async def get_all_users(self) -> list[User]:
48+
...
49+
50+
async def store_user(self, payload: UserPayload) -> User:
51+
...
52+
53+
async def delete_user(self, user_id: int) -> None:
54+
...
55+
56+
async def get_user(self, user_id: int) -> User | None:
57+
...
58+
59+
# stickers
60+
61+
async def get_all_stickers(self) -> list[GuildSticker]:
62+
...
63+
64+
async def get_sticker(self, sticker_id: int) -> GuildSticker:
65+
...
66+
67+
async def store_sticker(self, guild: Guild, data: GuildStickerPayload) -> GuildSticker:
68+
...
69+
70+
# interactions
71+
72+
async def store_view(self, view: View, message_id: int | None) -> None:
73+
...
74+
75+
async def delete_view_on(self, message_id: int) -> None:
76+
...
77+
78+
async def get_all_views(self) -> list[View]:
79+
...
80+
81+
async def store_modal(self, modal: Modal) -> None:
82+
...
83+
84+
async def get_all_modals(self) -> list[Modal]:
85+
...
86+
87+
# guilds
88+
89+
async def get_all_guilds(self) -> list[Guild]:
90+
...
91+
92+
async def get_guild(self, id: int) -> Guild:
93+
...
94+
95+
async def add_guild(self, guild: Guild) -> None:
96+
...
97+
98+
async def delete_guild(self, guild: Guild) -> None:
99+
...
100+
101+
# emojis
102+
103+
async def store_guild_emoji(self, guild: Guild, data: EmojiPayload) -> GuildEmoji:
104+
...
105+
106+
async def store_app_emoji(
107+
self, application_id: int, data: EmojiPayload
108+
) -> AppEmoji:
109+
...
110+
111+
async def get_all_emojis(self) -> list[GuildEmoji | AppEmoji]:
112+
...
113+
114+
async def get_emoji(self, emoji_id: int | None) -> GuildEmoji | AppEmoji | None:
115+
...
116+
117+
async def delete_emoji(self, emoji: GuildEmoji | AppEmoji) -> None:
118+
...
119+
120+
# polls
121+
122+
async def get_all_polls(self) -> list[Poll]:
123+
...
124+
125+
async def get_poll(self, message_id: int) -> Poll:
126+
...
127+
128+
async def store_poll(self, poll: Poll, message_id: int) -> None:
129+
...
130+
131+
# private channels
132+
133+
async def get_private_channels(self) -> list[PrivateChannel]:
134+
...
135+
136+
async def get_private_channel(self, channel_id: int) -> PrivateChannel:
137+
...
138+
139+
async def store_private_channel(self, channel: PrivateChannel, channel_id: int) -> None:
140+
...
141+
142+
# dm channels
143+
144+
async def get_dm_channels(self) -> list[DMChannel]:
145+
...
146+
147+
async def get_dm_channel(self, channel_id: int) -> DMChannel:
148+
...
149+
150+
async def store_dm_channel(self, channel: DMChannelPayload, channel_id: int) -> DMChannel:
151+
...
152+
153+
def clear(self, views: bool = True) -> None:
154+
...
155+
156+
class MemoryCache(Cache):
157+
def __init__(self, max_messages: int | None = None, *, state: ConnectionState):
158+
self._state = state
159+
self.max_messages = max_messages
160+
self.clear()
161+
162+
def clear(self, views: bool = True) -> None:
163+
self._users: dict[int, User] = {}
164+
self._guilds: dict[int, Guild] = {}
165+
self._polls: dict[int, Poll] = {}
166+
self._stickers: dict[int, list[GuildSticker]] = {}
167+
if views:
168+
self._views: dict[str, View] = {}
169+
self._modals: dict[str, Modal] = {}
170+
self._messages: Deque[Message] = deque(maxlen=self.max_messages)
171+
172+
self._emojis = dict[str, GuildEmoji | AppEmoji] = {}
173+
174+
self._private_channels: OrderedDict[int, PrivateChannel] = OrderedDict()
175+
self._private_channels_by_user: dict[int, DMChannel] = {}
176+
177+
# users
178+
async def get_all_users(self) -> list[User]:
179+
return list(self._users.values())
180+
181+
async def store_user(self, payload: UserPayload) -> User:
182+
user_id = int(payload["id"])
183+
try:
184+
return self._users[user_id]
185+
except KeyError:
186+
user = User(state=self, data=payload)
187+
if user.discriminator != "0000":
188+
self._users[user_id] = user
189+
user._stored = True
190+
return user
191+
192+
async def delete_user(self, user_id: int) -> None:
193+
self._users.pop(user_id, None)
194+
195+
async def get_user(self, user_id: int) -> User:
196+
return self._users.get(user_id)
197+
198+
# stickers
199+
200+
async def get_all_stickers(self) -> list[GuildSticker]:
201+
return list(self._stickers.values())
202+
203+
async def get_sticker(self, sticker_id: int) -> GuildSticker:
204+
return self._stickers.get(sticker_id)
205+
206+
async def store_sticker(self, guild: Guild, data: GuildStickerPayload) -> GuildSticker:
207+
sticker = GuildSticker(state=self._state, data=data)
208+
try:
209+
self._stickers[guild.id].append(sticker)
210+
except KeyError:
211+
self._stickers[guild.id] = sticker
212+
return sticker
213+
214+
# interactions
215+
216+
async def delete_view_on(self, message_id: int) -> View | None:
217+
return self._views.pop(message_id, None)
218+
219+
async def store_view(self, view: View, message_id: int) -> None:
220+
self._views[message_id or view.id] = view
221+
222+
async def get_all_views(self) -> list[View]:
223+
return list(self._views.values())
224+
225+
async def store_modal(self, modal: Modal) -> None:
226+
self._modals[modal.custom_id] = modal
227+
228+
async def get_all_modals(self) -> list[Modal]:
229+
return list(self._modals.values())
230+
231+
# guilds
232+
233+
async def get_all_guilds(self) -> list[Guild]:
234+
return list(self._guilds.values())
235+
236+
async def get_guild(self, id: int) -> Guild | None:
237+
return self._guilds.get(id)
238+
239+
async def add_guild(self, guild: Guild) -> None:
240+
self._guilds[guild.id] = guild
241+
242+
async def delete_guild(self, guild: Guild) -> None:
243+
self._guilds.pop(guild.id, None)
244+
245+
# emojis
246+
247+
async def store_guild_emoji(self, guild: Guild, data: EmojiPayload) -> GuildEmoji:
248+
emoji = GuildEmoji(guild=guild, state=self._state, data=data)
249+
try:
250+
self._emojis[guild.id].append(emoji)
251+
except KeyError:
252+
self._emojis[guild.id] = [emoji]
253+
return emoji
254+
255+
async def store_app_emoji(
256+
self, application_id: int, data: EmojiPayload
257+
) -> AppEmoji:
258+
emoji = AppEmoji(application_id=application_id, state=self._state, data=data)
259+
try:
260+
self._emojis[application_id].append(emoji)
261+
except KeyError:
262+
self._emojis[application_id] = [emoji]
263+
return emoji
264+
265+
async def get_all_emojis(self) -> list[GuildEmoji | AppEmoji]:
266+
return list(self._emojis.values())
267+
268+
async def get_emoji(self, emoji_id: int | None) -> GuildEmoji | AppEmoji | None:
269+
return self._emojis.get(emoji_id)
270+
271+
async def delete_emoji(self, emoji: GuildEmoji | AppEmoji) -> None:
272+
if isinstance(emoji, AppEmoji):
273+
self._emojis[emoji.application_id].remove(emoji)
274+
else:
275+
self._emojis[emoji.guild_id].remove(emoji)
276+
277+
# polls
278+
279+
async def get_all_polls(self) -> list[Poll]:
280+
return list(self._polls.values())
281+
282+
async def get_poll(self, message_id: int) -> Poll | None:
283+
return self._polls.get(message_id)
284+
285+
async def store_poll(self, poll: Poll, message_id: int) -> None:
286+
self._polls[message_id] = poll
287+
288+
# private channels
289+
290+
async def get_private_channels(self) -> list[PrivateChannel]:
291+
return list(self._private_channels.values())
292+
293+
async def get_private_channel(self, channel_id: int) -> PrivateChannel | None:
294+
return self._private_channels.get(channel_id)
295+
296+
async def store_private_channel(self, channel: PrivateChannel) -> None:
297+
channel_id = channel.id
298+
self._private_channels[channel_id] = channel
299+
300+
if len(self._private_channels) > 128:
301+
_, to_remove = self._private_channels.popitem(last=False)
302+
if isinstance(to_remove, DMChannel) and to_remove.recipient:
303+
self._private_channels_by_user.pop(to_remove.recipient.id, None)
304+
305+
if isinstance(channel, DMChannel) and channel.recipient:
306+
self._private_channels_by_user[channel.recipient.id] = channel

0 commit comments

Comments
 (0)