Skip to content

Commit 0d0ac28

Browse files
committed
refactor!: bulk of the work
large part left is the ConnectionState itself since events haven't been transformed yet
1 parent 5904f0a commit 0d0ac28

File tree

48 files changed

+1233
-1242
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+1233
-1242
lines changed

discord/abc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,7 @@ def __str__(self) -> str:
350350
def _sorting_bucket(self) -> int:
351351
raise NotImplementedError
352352

353-
def _update(self, guild: Guild, data: dict[str, Any]) -> None:
353+
async def _update(self, data: dict[str, Any]) -> None:
354354
raise NotImplementedError
355355

356356
async def _move(
@@ -1283,7 +1283,7 @@ async def create_invite(
12831283
target_user_id=target_user.id if target_user else None,
12841284
target_application_id=target_application_id,
12851285
)
1286-
invite = Invite.from_incomplete(data=data, state=self._state)
1286+
invite = await Invite.from_incomplete(data=data, state=self._state)
12871287
if target_event:
12881288
invite.set_scheduled_event(target_event)
12891289
return invite

discord/app/cache.py

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,11 @@
2525
from collections import OrderedDict, deque
2626
from typing import Deque, Protocol
2727

28+
from discord import utils
2829
from discord.app.state import ConnectionState
2930
from discord.message import Message
3031

31-
from ..abc import PrivateChannel
32+
from ..abc import MessageableChannel, PrivateChannel
3233
from ..channel import DMChannel
3334
from ..emoji import AppEmoji, GuildEmoji
3435
from ..guild import Guild
@@ -41,6 +42,7 @@
4142
from ..types.emoji import Emoji as EmojiPayload
4243
from ..types.sticker import GuildSticker as GuildStickerPayload
4344
from ..types.channel import DMChannel as DMChannelPayload
45+
from ..types.message import Message as MessagePayload
4446

4547
class Cache(Protocol):
4648
# users
@@ -67,6 +69,9 @@ async def get_sticker(self, sticker_id: int) -> GuildSticker:
6769
async def store_sticker(self, guild: Guild, data: GuildStickerPayload) -> GuildSticker:
6870
...
6971

72+
async def delete_sticker(self, sticker_id: int) -> None:
73+
...
74+
7075
# interactions
7176

7277
async def store_view(self, view: View, message_id: int | None) -> None:
@@ -89,7 +94,7 @@ async def get_all_modals(self) -> list[Modal]:
8994
async def get_all_guilds(self) -> list[Guild]:
9095
...
9196

92-
async def get_guild(self, id: int) -> Guild:
97+
async def get_guild(self, id: int) -> Guild | None:
9398
...
9499

95100
async def add_guild(self, guild: Guild) -> None:
@@ -136,18 +141,24 @@ async def get_private_channels(self) -> list[PrivateChannel]:
136141
async def get_private_channel(self, channel_id: int) -> PrivateChannel:
137142
...
138143

139-
async def store_private_channel(self, channel: PrivateChannel, channel_id: int) -> None:
144+
async def get_private_channel_by_user(self, user_id: int) -> PrivateChannel:
140145
...
141146

142-
# dm channels
147+
async def store_private_channel(self, channel: PrivateChannel) -> None:
148+
...
143149

144-
async def get_dm_channels(self) -> list[DMChannel]:
150+
# messages
151+
152+
async def store_message(self, message: MessagePayload, channel: MessageableChannel) -> Message:
153+
...
154+
155+
async def delete_message(self, message_id: int) -> None:
145156
...
146157

147-
async def get_dm_channel(self, channel_id: int) -> DMChannel:
158+
async def get_message(self, message_id: int) -> Message | None:
148159
...
149160

150-
async def store_dm_channel(self, channel: DMChannelPayload, channel_id: int) -> DMChannel:
161+
async def get_all_messages(self) -> list[Message]:
151162
...
152163

153164
def clear(self, views: bool = True) -> None:
@@ -211,6 +222,9 @@ async def store_sticker(self, guild: Guild, data: GuildStickerPayload) -> GuildS
211222
self._stickers[guild.id] = sticker
212223
return sticker
213224

225+
async def delete_sticker(self, sticker_id: int) -> None:
226+
self._stickers.pop(sticker_id, None)
227+
214228
# interactions
215229

216230
async def delete_view_on(self, message_id: int) -> View | None:
@@ -291,7 +305,13 @@ async def get_private_channels(self) -> list[PrivateChannel]:
291305
return list(self._private_channels.values())
292306

293307
async def get_private_channel(self, channel_id: int) -> PrivateChannel | None:
294-
return self._private_channels.get(channel_id)
308+
try:
309+
channel = self._private_channels[channel_id]
310+
except KeyError:
311+
return None
312+
else:
313+
self._private_channels.move_to_end(channel_id)
314+
return channel
295315

296316
async def store_private_channel(self, channel: PrivateChannel) -> None:
297317
channel_id = channel.id
@@ -304,3 +324,17 @@ async def store_private_channel(self, channel: PrivateChannel) -> None:
304324

305325
if isinstance(channel, DMChannel) and channel.recipient:
306326
self._private_channels_by_user[channel.recipient.id] = channel
327+
328+
async def get_private_channel_by_user(self, user_id: int) -> PrivateChannel | None:
329+
return self._private_channels_by_user.get(user_id)
330+
331+
# messages
332+
333+
async def store_message(self, message: MessagePayload, channel: MessageableChannel) -> Message:
334+
msg = await Message._from_data(state=self._state, channel=channel, data=message)
335+
336+
async def get_message(self, message_id: int) -> Message | None:
337+
return utils.find(lambda m: m.id == message_id, reversed(self._messages))
338+
339+
async def get_all_messages(self) -> list[Message]:
340+
return list(self._messages)

discord/app/events.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,17 @@
2222
DEALINGS IN THE SOFTWARE.
2323
"""
2424

25+
from abc import ABC
2526
from asyncio import Future
2627
import asyncio
27-
from typing import Any, Callable, Protocol, Self, Type, TypeVar
28+
from typing import Any, Callable, Self, Type, TypeVar
2829

2930
from .state import ConnectionState
3031

3132
T = TypeVar('T')
3233

3334

34-
class Event(Protocol):
35+
class Event(ABC):
3536
__event_name__: str
3637

3738
@classmethod

discord/app/state.py

Lines changed: 55 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,12 @@ def __init__(
114114
self.buffer: list[Member] = []
115115
self.waiters: list[asyncio.Future[list[Member]]] = []
116116

117-
def add_members(self, members: list[Member]) -> None:
117+
async def add_members(self, members: list[Member]) -> None:
118118
self.buffer.extend(members)
119119
if self.cache:
120120
guild = self.resolver(self.guild_id)
121+
if inspect.isawaitable(guild):
122+
guild = await guild
121123
if guild is None:
122124
return
123125

@@ -364,125 +366,87 @@ async def store_modal(self, modal: Modal) -> None:
364366
async def prevent_view_updates_for(self, message_id: int) -> View | None:
365367
return await self.cache.delete_view_on(message_id)
366368

367-
@property
368-
def persistent_views(self) -> Sequence[View]:
369-
return self._view_store.persistent_views
369+
async def get_persistent_views(self) -> Sequence[View]:
370+
views = await self.cache.get_all_views()
371+
persistent_views = {
372+
view.id: view
373+
for view in views
374+
if view.is_persistent()
375+
}
376+
return list(persistent_views.values())
370377

371-
@property
372-
def guilds(self) -> list[Guild]:
373-
return list(self._guilds.values())
378+
async def get_guilds(self) -> list[Guild]:
379+
return await self.cache.get_all_guilds()
374380

375-
def _get_guild(self, guild_id: int | None) -> Guild | None:
376-
# the keys of self._guilds are ints
377-
return self._guilds.get(guild_id) # type: ignore
381+
async def _get_guild(self, guild_id: int | None) -> Guild | None:
382+
return await self.cache.get_guild(guild_id)
378383

379-
def _add_guild(self, guild: Guild) -> None:
380-
self._guilds[guild.id] = guild
384+
async def _add_guild(self, guild: Guild) -> None:
385+
await self.cache.add_guild(guild)
381386

382-
def _remove_guild(self, guild: Guild) -> None:
383-
self._guilds.pop(guild.id, None)
387+
async def _remove_guild(self, guild: Guild) -> None:
388+
await self.cache.delete_guild(guild)
384389

385390
for emoji in guild.emojis:
386-
self._remove_emoji(emoji)
391+
await self.cache.delete_emoji(emoji)
387392

388393
for sticker in guild.stickers:
389-
self._stickers.pop(sticker.id, None)
394+
await self.cache.delete_sticker(sticker.id)
390395

391396
del guild
392397

393-
@property
394-
def emojis(self) -> list[GuildEmoji | AppEmoji]:
395-
return list(self._emojis.values())
398+
async def get_emojis(self) -> list[GuildEmoji | AppEmoji]:
399+
return await self.cache.get_all_emojis()
396400

397-
@property
398-
def stickers(self) -> list[GuildSticker]:
399-
return list(self._stickers.values())
401+
async def get_stickers(self) -> list[GuildSticker]:
402+
return await self.cache.get_all_stickers()
400403

401-
def get_emoji(self, emoji_id: int | None) -> GuildEmoji | AppEmoji | None:
402-
# the keys of self._emojis are ints
403-
return self._emojis.get(emoji_id) # type: ignore
404+
async def get_emoji(self, emoji_id: int | None) -> GuildEmoji | AppEmoji | None:
405+
return await self.get_emoji(emoji_id)
404406

405-
def _remove_emoji(self, emoji: GuildEmoji | AppEmoji) -> None:
406-
self._emojis.pop(emoji.id, None)
407+
async def _remove_emoji(self, emoji: GuildEmoji | AppEmoji) -> None:
408+
await self.cache.delete_emoji(emoji)
407409

408-
def get_sticker(self, sticker_id: int | None) -> GuildSticker | None:
409-
# the keys of self._stickers are ints
410-
return self._stickers.get(sticker_id) # type: ignore
410+
async def get_sticker(self, sticker_id: int | None) -> GuildSticker | None:
411+
return await self.cache.get_sticker(sticker_id)
411412

412-
@property
413-
def polls(self) -> list[Poll]:
414-
return list(self._polls.values())
413+
async def get_polls(self) -> list[Poll]:
414+
return await self.cache.get_all_polls()
415415

416-
def store_raw_poll(self, poll: PollPayload, raw):
416+
def create_poll(self, poll: PollPayload, raw) -> Poll:
417417
channel = self.get_channel(raw.channel_id) or PartialMessageable(
418418
state=self, id=raw.channel_id
419419
)
420420
message = channel.get_partial_message(raw.message_id)
421-
p = Poll.from_dict(poll, message)
422-
self._polls[message.id] = p
423-
return p
421+
return Poll.from_dict(poll, message)
424422

425-
def store_poll(self, poll: Poll, message_id: int):
426-
self._polls[message_id] = poll
423+
async def store_poll(self, poll: Poll, message_id: int):
424+
await self.cache.store_poll(poll, message_id)
427425

428-
def get_poll(self, message_id):
429-
return self._polls.get(message_id)
430-
431-
@property
432-
def private_channels(self) -> list[PrivateChannel]:
433-
return list(self._private_channels.values())
434-
435-
def _get_private_channel(self, channel_id: int | None) -> PrivateChannel | None:
436-
try:
437-
# the keys of self._private_channels are ints
438-
value = self._private_channels[channel_id] # type: ignore
439-
except KeyError:
440-
return None
441-
else:
442-
self._private_channels.move_to_end(channel_id) # type: ignore
443-
return value
426+
async def get_poll(self, message_id: int):
427+
return await self.cache.get_poll(message_id)
444428

445-
def _get_private_channel_by_user(self, user_id: int | None) -> DMChannel | None:
446-
# the keys of self._private_channels are ints
447-
return self._private_channels_by_user.get(user_id) # type: ignore
429+
async def get_private_channels(self) -> list[PrivateChannel]:
430+
return await self.cache.get_private_channels()
448431

449-
def _add_private_channel(self, channel: PrivateChannel) -> None:
450-
channel_id = channel.id
451-
self._private_channels[channel_id] = channel
432+
async def _get_private_channel(self, channel_id: int | None) -> PrivateChannel | None:
433+
return await self.cache.get_private_channel(channel_id)
452434

453-
if len(self._private_channels) > 128:
454-
_, to_remove = self._private_channels.popitem(last=False)
455-
if isinstance(to_remove, DMChannel) and to_remove.recipient:
456-
self._private_channels_by_user.pop(to_remove.recipient.id, None)
435+
async def _get_private_channel_by_user(self, user_id: int | None) -> DMChannel | None:
436+
return await self.cache.get_private_channel_by_user(user_id)
457437

458-
if isinstance(channel, DMChannel) and channel.recipient:
459-
self._private_channels_by_user[channel.recipient.id] = channel
438+
async def _add_private_channel(self, channel: PrivateChannel) -> None:
439+
await self.cache.store_private_channel(channel)
460440

461441
async def add_dm_channel(self, data: DMChannelPayload) -> DMChannel:
462442
# self.user is *always* cached when this is called
463443
channel = DMChannel(me=self.user, state=self, data=data) # type: ignore
464444
await channel._load()
465-
self._add_private_channel(channel)
445+
await self._add_private_channel(channel)
466446
return channel
467447

468-
def _remove_private_channel(self, channel: PrivateChannel) -> None:
469-
self._private_channels.pop(channel.id, None)
470-
if isinstance(channel, DMChannel):
471-
recipient = channel.recipient
472-
if recipient is not None:
473-
self._private_channels_by_user.pop(recipient.id, None)
474-
475-
def _get_message(self, msg_id: int | None) -> Message | None:
476-
return (
477-
utils.find(lambda m: m.id == msg_id, reversed(self._messages))
478-
if self._messages
479-
else None
480-
)
481-
482-
def _add_guild_from_data(self, data: GuildPayload) -> Guild:
483-
guild = Guild(data=data, state=self)
484-
self._add_guild(guild)
485-
return guild
448+
async def _get_message(self, msg_id: int | None) -> Message | None:
449+
return await self.cache.get_message(msg_id)
486450

487451
def _guild_needs_chunking(self, guild: Guild) -> bool:
488452
# If presences are enabled then we get back the old guild.large behaviour
@@ -492,12 +456,12 @@ def _guild_needs_chunking(self, guild: Guild) -> bool:
492456
and not (self._intents.presences and not guild.large)
493457
)
494458

495-
def _get_guild_channel(
459+
async def _get_guild_channel(
496460
self, data: MessagePayload, guild_id: int | None = None
497461
) -> tuple[Channel | Thread, Guild | None]:
498462
channel_id = int(data["channel_id"])
499463
try:
500-
guild = self._get_guild(int(guild_id or data["guild_id"]))
464+
guild = await self._get_guild(int(guild_id or data["guild_id"]))
501465
except KeyError:
502466
channel = DMChannel._from_message(self, channel_id)
503467
guild = None
@@ -1950,15 +1914,15 @@ def _upgrade_partial_emoji(
19501914
except KeyError:
19511915
return emoji
19521916

1953-
def get_channel(self, id: int | None) -> Channel | Thread | None:
1917+
async def get_channel(self, id: int | None) -> Channel | Thread | None:
19541918
if id is None:
19551919
return None
19561920

1957-
pm = self._get_private_channel(id)
1921+
pm = await self._get_private_channel(id)
19581922
if pm is not None:
19591923
return pm
19601924

1961-
for guild in self.guilds:
1925+
for guild in await self.cache.get_all_guilds():
19621926
channel = guild._resolve_channel(id)
19631927
if channel is not None:
19641928
return channel

discord/appinfo.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -252,14 +252,13 @@ def cover_image(self) -> Asset | None:
252252
return None
253253
return Asset._from_cover_image(self._state, self.id, self._cover_image)
254254

255-
@property
256-
def guild(self) -> Guild | None:
255+
async def get_guild(self) -> Guild | None:
257256
"""If this application is a game sold on Discord,
258257
this field will be the guild to which it has been linked.
259258
260259
.. versionadded:: 1.3
261260
"""
262-
return self._state._get_guild(self.guild_id)
261+
return await self._state._get_guild(self.guild_id)
263262

264263
@property
265264
def summary(self) -> str | None:

0 commit comments

Comments
 (0)