Skip to content

Commit 79883a0

Browse files
committed
refactor!: miscellaneous
1 parent 79c6bf4 commit 79883a0

File tree

4 files changed

+34
-24
lines changed

4 files changed

+34
-24
lines changed

discord/app/cache.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ async def get_private_channels(self) -> list[PrivateChannel]:
147147
async def get_private_channel(self, channel_id: int) -> PrivateChannel:
148148
...
149149

150-
async def get_private_channel_by_user(self, user_id: int) -> PrivateChannel:
150+
async def get_private_channel_by_user(self, user_id: int) -> PrivateChannel | None:
151151
...
152152

153153
async def store_private_channel(self, channel: PrivateChannel) -> None:
@@ -190,19 +190,32 @@ async def get_guild_members(self, guild_id: int) -> list[Member]:
190190
async def get_all_members(self) -> list[Member]:
191191
...
192192

193-
def clear(self, views: bool = True) -> None:
193+
async def clear(self, views: bool = True) -> None:
194194
...
195195

196196
class MemoryCache(Cache):
197197
def __init__(self, max_messages: int | None = None, *, state: ConnectionState):
198198
self._state = state
199199
self.max_messages = max_messages
200-
self.clear()
200+
self._users: dict[int, User] = {}
201+
self._guilds: dict[int, Guild] = {}
202+
self._polls: dict[int, Poll] = {}
203+
self._stickers: dict[int, list[GuildSticker]] = {}
204+
self._views: dict[str, View] = {}
205+
self._modals: dict[str, Modal] = {}
206+
self._messages: Deque[Message] = deque(maxlen=self.max_messages)
207+
208+
self._emojis: dict[int, list[GuildEmoji | AppEmoji]] = {}
209+
210+
self._private_channels: OrderedDict[int, PrivateChannel] = OrderedDict()
211+
self._private_channels_by_user: dict[int, DMChannel] = {}
212+
213+
self._guild_members: dict[int, dict[int, Member]] = defaultdict(dict)
201214

202215
def _flatten(self, matrix: list[list[T]]) -> list[T]:
203216
return [item for row in matrix for item in row]
204217

205-
def clear(self, views: bool = True) -> None:
218+
async def clear(self, views: bool = True) -> None:
206219
self._users: dict[int, User] = {}
207220
self._guilds: dict[int, Guild] = {}
208221
self._polls: dict[int, Poll] = {}

discord/app/event_emitter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@
2222
DEALINGS IN THE SOFTWARE.
2323
"""
2424

25-
from abc import ABC
25+
from abc import ABC, abstractmethod
2626
from asyncio import Future
2727
import asyncio
28+
from collections import defaultdict
2829
from typing import Any, Callable, Self, TypeVar
2930

3031
from .state import ConnectionState

discord/app/state.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
Sequence,
4242
TypeVar,
4343
Union,
44+
cast,
4445
)
4546

4647
from discord.app.event_emitter import EventEmitter
@@ -263,9 +264,9 @@ def __init__(
263264

264265
self.cache: Cache = self.cache
265266

266-
def clear(self, *, views: bool = True) -> None:
267+
async def clear(self, *, views: bool = True) -> None:
267268
self.user: ClientUser | None = None
268-
self.cache.clear()
269+
await self.cache.clear()
269270
self._voice_clients: dict[int, VoiceClient] = {}
270271

271272
async def process_chunk_requests(
@@ -340,7 +341,7 @@ def deref_user_no_intents(self, user_id: int) -> None:
340341
return
341342

342343
async def get_user(self, id: int | None) -> User | None:
343-
return await self.cache.get_user(id)
344+
return await self.cache.get_user(cast(int, id))
344345

345346
async def store_emoji(self, guild: Guild, data: EmojiPayload) -> GuildEmoji:
346347
return await self.cache.store_guild_emoji(guild, data)
@@ -360,8 +361,8 @@ async def store_sticker(self, guild: Guild, data: GuildStickerPayload) -> GuildS
360361
async def store_view(self, view: View, message_id: int | None = None) -> None:
361362
await self.cache.store_view(view, message_id)
362363

363-
async def store_modal(self, modal: Modal) -> None:
364-
await self.cache.store_modal(modal)
364+
async def store_modal(self, modal: Modal, user_id: int) -> None:
365+
await self.cache.store_modal(modal, user_id)
365366

366367
async def prevent_view_updates_for(self, message_id: int) -> View | None:
367368
return await self.cache.delete_view_on(message_id)
@@ -379,7 +380,7 @@ async def get_guilds(self) -> list[Guild]:
379380
return await self.cache.get_all_guilds()
380381

381382
async def _get_guild(self, guild_id: int | None) -> Guild | None:
382-
return await self.cache.get_guild(guild_id)
383+
return await self.cache.get_guild(cast(int, guild_id))
383384

384385
async def _add_guild(self, guild: Guild) -> None:
385386
await self.cache.add_guild(guild)
@@ -408,18 +409,11 @@ async def _remove_emoji(self, emoji: GuildEmoji | AppEmoji) -> None:
408409
await self.cache.delete_emoji(emoji)
409410

410411
async def get_sticker(self, sticker_id: int | None) -> GuildSticker | None:
411-
return await self.cache.get_sticker(sticker_id)
412+
return await self.cache.get_sticker(cast(int, sticker_id))
412413

413414
async def get_polls(self) -> list[Poll]:
414415
return await self.cache.get_all_polls()
415416

416-
def create_poll(self, poll: PollPayload, raw) -> Poll:
417-
channel = self.get_channel(raw.channel_id) or PartialMessageable(
418-
state=self, id=raw.channel_id
419-
)
420-
message = channel.get_partial_message(raw.message_id)
421-
return Poll.from_dict(poll, message)
422-
423417
async def store_poll(self, poll: Poll, message_id: int):
424418
await self.cache.store_poll(poll, message_id)
425419

@@ -430,10 +424,10 @@ async def get_private_channels(self) -> list[PrivateChannel]:
430424
return await self.cache.get_private_channels()
431425

432426
async def _get_private_channel(self, channel_id: int | None) -> PrivateChannel | None:
433-
return await self.cache.get_private_channel(channel_id)
427+
return await self.cache.get_private_channel(cast(int, channel_id))
434428

435429
async def _get_private_channel_by_user(self, user_id: int | None) -> DMChannel | None:
436-
return await self.cache.get_private_channel_by_user(user_id)
430+
return cast(DMChannel | None, await self.cache.get_private_channel_by_user(cast(int, user_id)))
437431

438432
async def _add_private_channel(self, channel: PrivateChannel) -> None:
439433
await self.cache.store_private_channel(channel)
@@ -446,7 +440,7 @@ async def add_dm_channel(self, data: DMChannelPayload) -> DMChannel:
446440
return channel
447441

448442
async def _get_message(self, msg_id: int | None) -> Message | None:
449-
return await self.cache.get_message(msg_id)
443+
return await self.cache.get_message(cast(int, msg_id))
450444

451445
def _guild_needs_chunking(self, guild: Guild) -> bool:
452446
# If presences are enabled then we get back the old guild.large behaviour
@@ -461,7 +455,8 @@ async def _get_guild_channel(
461455
) -> tuple[Channel | Thread, Guild | None]:
462456
channel_id = int(data["channel_id"])
463457
try:
464-
guild = await self._get_guild(int(guild_id or data["guild_id"]))
458+
# guild_id is in data
459+
guild = await self._get_guild(int(guild_id or data["guild_id"])) # type: ignore
465460
except KeyError:
466461
channel = DMChannel._from_message(self, channel_id)
467462
guild = None

discord/interactions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1243,7 +1243,8 @@ async def send_modal(self, modal: Modal) -> Interaction:
12431243
)
12441244
)
12451245
self._responded = True
1246-
await self._parent._state.store_modal(modal)
1246+
# _data should be present
1247+
await self._parent._state.store_modal(modal, int(self._parent._data["user"]["id"])) # type: ignore
12471248
return self._parent
12481249

12491250
@utils.deprecated("a button with type ButtonType.premium", "2.6")

0 commit comments

Comments
 (0)