Skip to content

Commit 56fdd91

Browse files
committed
feat: add more events
1 parent 3a43db0 commit 56fdd91

File tree

18 files changed

+470
-267
lines changed

18 files changed

+470
-267
lines changed

discord/abc.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -629,8 +629,7 @@ def overwrites_for(self, obj: Role | User) -> PermissionOverwrite:
629629

630630
return PermissionOverwrite()
631631

632-
@property
633-
def overwrites(self) -> dict[Role | Member, PermissionOverwrite]:
632+
async def get_overwrites(self) -> dict[Role | Member, PermissionOverwrite]:
634633
"""Returns all of the channel's overwrites.
635634
636635
This is returned as a dictionary where the key contains the target which
@@ -652,7 +651,7 @@ def overwrites(self) -> dict[Role | Member, PermissionOverwrite]:
652651
if ow.is_role():
653652
target = self.guild.get_role(ow.id)
654653
elif ow.is_member():
655-
target = self.guild.get_member(ow.id)
654+
target = await self.guild.get_member(ow.id)
656655

657656
# TODO: There is potential data loss here in the non-chunked
658657
# case, i.e. target is None because get_member returned nothing.

discord/app/cache.py

Lines changed: 63 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,12 @@
2222
DEALINGS IN THE SOFTWARE.
2323
"""
2424

25-
from collections import OrderedDict, deque
26-
from typing import Deque, Protocol
25+
from collections import OrderedDict, defaultdict, deque
26+
from typing import Deque, Protocol, TypeVar
2727

2828
from discord import utils
2929
from discord.app.state import ConnectionState
30+
from discord.member import Member
3031
from discord.message import Message
3132

3233
from ..abc import MessageableChannel, PrivateChannel
@@ -44,6 +45,8 @@
4445
from ..types.channel import DMChannel as DMChannelPayload
4546
from ..types.message import Message as MessagePayload
4647

48+
T = TypeVar('T')
49+
4750
class Cache(Protocol):
4851
# users
4952
async def get_all_users(self) -> list[User]:
@@ -167,6 +170,26 @@ async def get_message(self, message_id: int) -> Message | None:
167170
async def get_all_messages(self) -> list[Message]:
168171
...
169172

173+
# guild members
174+
175+
async def store_member(self, member: Member) -> None:
176+
...
177+
178+
async def get_member(self, guild_id: int, user_id: int) -> Member | None:
179+
...
180+
181+
async def delete_member(self, guild_id: int, user_id: int) -> None:
182+
...
183+
184+
async def delete_guild_members(self, guild_id: int) -> None:
185+
...
186+
187+
async def get_guild_members(self, guild_id: int) -> list[Member]:
188+
...
189+
190+
async def get_all_members(self) -> list[Member]:
191+
...
192+
170193
def clear(self, views: bool = True) -> None:
171194
...
172195

@@ -176,6 +199,9 @@ def __init__(self, max_messages: int | None = None, *, state: ConnectionState):
176199
self.max_messages = max_messages
177200
self.clear()
178201

202+
def _flatten(self, matrix: list[list[T]]) -> list[T]:
203+
return [item for row in matrix for item in row]
204+
179205
def clear(self, views: bool = True) -> None:
180206
self._users: dict[int, User] = {}
181207
self._guilds: dict[int, Guild] = {}
@@ -186,11 +212,13 @@ def clear(self, views: bool = True) -> None:
186212
self._modals: dict[str, Modal] = {}
187213
self._messages: Deque[Message] = deque(maxlen=self.max_messages)
188214

189-
self._emojis = dict[str, GuildEmoji | AppEmoji] = {}
215+
self._emojis: dict[int, list[GuildEmoji | AppEmoji]] = {}
190216

191217
self._private_channels: OrderedDict[int, PrivateChannel] = OrderedDict()
192218
self._private_channels_by_user: dict[int, DMChannel] = {}
193219

220+
self._guild_members: dict[int, dict[int, Member]] = defaultdict(dict)
221+
194222
# users
195223
async def get_all_users(self) -> list[User]:
196224
return list(self._users.values())
@@ -200,7 +228,7 @@ async def store_user(self, payload: UserPayload) -> User:
200228
try:
201229
return self._users[user_id]
202230
except KeyError:
203-
user = User(state=self, data=payload)
231+
user = User(state=self._state, data=payload)
204232
if user.discriminator != "0000":
205233
self._users[user_id] = user
206234
user._stored = True
@@ -209,16 +237,19 @@ async def store_user(self, payload: UserPayload) -> User:
209237
async def delete_user(self, user_id: int) -> None:
210238
self._users.pop(user_id, None)
211239

212-
async def get_user(self, user_id: int) -> User:
240+
async def get_user(self, user_id: int) -> User | None:
213241
return self._users.get(user_id)
214242

215243
# stickers
216244

217245
async def get_all_stickers(self) -> list[GuildSticker]:
218-
return list(self._stickers.values())
246+
return self._flatten(list(self._stickers.values()))
219247

220248
async def get_sticker(self, sticker_id: int) -> GuildSticker | None:
221-
return self._stickers.get(sticker_id)
249+
stickers = self._flatten(list(self._stickers.values()))
250+
for sticker in stickers:
251+
if sticker.id == sticker_id:
252+
return sticker
222253

223254
async def store_sticker(self, guild: Guild, data: GuildStickerPayload) -> GuildSticker:
224255
sticker = GuildSticker(state=self._state, data=data)
@@ -285,10 +316,13 @@ async def store_app_emoji(
285316
return emoji
286317

287318
async def get_all_emojis(self) -> list[GuildEmoji | AppEmoji]:
288-
return list(self._emojis.values())
319+
return self._flatten(list(self._emojis.values()))
289320

290321
async def get_emoji(self, emoji_id: int | None) -> GuildEmoji | AppEmoji | None:
291-
return self._emojis.get(emoji_id)
322+
emojis = self._flatten(list(self._emojis.values()))
323+
for emoji in emojis:
324+
if emoji.id == emoji_id:
325+
return emoji
292326

293327
async def delete_emoji(self, emoji: GuildEmoji | AppEmoji) -> None:
294328
if isinstance(emoji, AppEmoji):
@@ -354,3 +388,23 @@ async def get_all_messages(self) -> list[Message]:
354388

355389
async def delete_modal(self, custom_id: str) -> None:
356390
self._modals.pop(custom_id, None)
391+
392+
# guild members
393+
394+
async def store_member(self, member: Member) -> None:
395+
self._guild_members[member.guild.id][member.id] = member
396+
397+
async def get_member(self, guild_id: int, user_id: int) -> Member | None:
398+
return self._guild_members[guild_id].get(user_id)
399+
400+
async def delete_member(self, guild_id: int, user_id: int) -> None:
401+
self._guild_members[guild_id].pop(user_id, None)
402+
403+
async def delete_guild_members(self, guild_id: int) -> None:
404+
self._guild_members.pop(guild_id, None)
405+
406+
async def get_guild_members(self, guild_id: int) -> list[Member]:
407+
return list(self._guild_members.get(guild_id, {}).values())
408+
409+
async def get_all_members(self) -> list[Member]:
410+
return self._flatten([list(members.values()) for members in self._guild_members.values()])

discord/app/state.py

Lines changed: 3 additions & 171 deletions
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,9 @@ async def add_members(self, members: list[Member]) -> None:
126126
return
127127

128128
for member in members:
129-
existing = guild.get_member(member.id)
129+
existing = await guild.get_member(member.id)
130130
if existing is None or existing.joined_at is None:
131-
guild._add_member(member)
131+
await guild._add_member(member)
132132

133133
async def wait(self) -> list[Member]:
134134
future = self.loop.create_future()
@@ -524,174 +524,6 @@ async def query_members(
524524
)
525525
raise
526526

527-
def parse_presence_update(self, data) -> None:
528-
guild_id = utils._get_as_snowflake(data, "guild_id")
529-
# guild_id won't be None here
530-
guild = self._get_guild(guild_id)
531-
if guild is None:
532-
_log.debug(
533-
"PRESENCE_UPDATE referencing an unknown guild ID: %s. Discarding.",
534-
guild_id,
535-
)
536-
return
537-
538-
user = data["user"]
539-
member_id = int(user["id"])
540-
member = guild.get_member(member_id)
541-
if member is None:
542-
_log.debug(
543-
"PRESENCE_UPDATE referencing an unknown member ID: %s. Discarding",
544-
member_id,
545-
)
546-
return
547-
548-
old_member = Member._copy(member)
549-
user_update = member._presence_update(data=data, user=user)
550-
if user_update:
551-
self.dispatch("user_update", user_update[0], user_update[1])
552-
553-
self.dispatch("presence_update", old_member, member)
554-
555-
def parse_user_update(self, data) -> None:
556-
# self.user is *always* cached when this is called
557-
user: ClientUser = self.user # type: ignore
558-
user._update(data)
559-
ref = self._users.get(user.id)
560-
if ref:
561-
ref._update(data)
562-
563-
def parse_invite_create(self, data) -> None:
564-
invite = Invite.from_gateway(state=self, data=data)
565-
self.dispatch("invite_create", invite)
566-
567-
def parse_invite_delete(self, data) -> None:
568-
invite = Invite.from_gateway(state=self, data=data)
569-
self.dispatch("invite_delete", invite)
570-
571-
def parse_channel_delete(self, data) -> None:
572-
guild = self._get_guild(utils._get_as_snowflake(data, "guild_id"))
573-
channel_id = int(data["id"])
574-
if guild is not None:
575-
channel = guild.get_channel(channel_id)
576-
if channel is not None:
577-
guild._remove_channel(channel)
578-
self.dispatch("guild_channel_delete", channel)
579-
580-
def parse_channel_update(self, data) -> None:
581-
channel_type = try_enum(ChannelType, data.get("type"))
582-
channel_id = int(data["id"])
583-
if channel_type is ChannelType.group:
584-
channel = self._get_private_channel(channel_id)
585-
old_channel = copy.copy(channel)
586-
# the channel is a GroupChannel
587-
channel._update_group(data) # type: ignore
588-
self.dispatch("private_channel_update", old_channel, channel)
589-
return
590-
591-
guild_id = utils._get_as_snowflake(data, "guild_id")
592-
guild = self._get_guild(guild_id)
593-
if guild is not None:
594-
channel = guild.get_channel(channel_id)
595-
if channel is not None:
596-
old_channel = copy.copy(channel)
597-
channel._update(guild, data)
598-
self.dispatch("guild_channel_update", old_channel, channel)
599-
else:
600-
_log.debug(
601-
"CHANNEL_UPDATE referencing an unknown channel ID: %s. Discarding.",
602-
channel_id,
603-
)
604-
else:
605-
_log.debug(
606-
"CHANNEL_UPDATE referencing an unknown guild ID: %s. Discarding.",
607-
guild_id,
608-
)
609-
610-
def parse_channel_create(self, data) -> None:
611-
factory, ch_type = _channel_factory(data["type"])
612-
if factory is None:
613-
_log.debug(
614-
"CHANNEL_CREATE referencing an unknown channel type %s. Discarding.",
615-
data["type"],
616-
)
617-
return
618-
619-
guild_id = utils._get_as_snowflake(data, "guild_id")
620-
guild = self._get_guild(guild_id)
621-
if guild is not None:
622-
# the factory can't be a DMChannel or GroupChannel here
623-
channel = factory(guild=guild, state=self, data=data) # type: ignore
624-
guild._add_channel(channel) # type: ignore
625-
self.dispatch("guild_channel_create", channel)
626-
else:
627-
_log.debug(
628-
"CHANNEL_CREATE referencing an unknown guild ID: %s. Discarding.",
629-
guild_id,
630-
)
631-
return
632-
633-
def parse_channel_pins_update(self, data) -> None:
634-
channel_id = int(data["channel_id"])
635-
try:
636-
guild = self._get_guild(int(data["guild_id"]))
637-
except KeyError:
638-
guild = None
639-
channel = self._get_private_channel(channel_id)
640-
else:
641-
channel = guild and guild._resolve_channel(channel_id)
642-
643-
if channel is None:
644-
_log.debug(
645-
(
646-
"CHANNEL_PINS_UPDATE referencing an unknown channel ID: %s."
647-
" Discarding."
648-
),
649-
channel_id,
650-
)
651-
return
652-
653-
last_pin = (
654-
utils.parse_time(data["last_pin_timestamp"])
655-
if data["last_pin_timestamp"]
656-
else None
657-
)
658-
659-
if guild is None:
660-
self.dispatch("private_channel_pins_update", channel, last_pin)
661-
else:
662-
self.dispatch("guild_channel_pins_update", channel, last_pin)
663-
664-
def parse_thread_create(self, data) -> None:
665-
guild_id = int(data["guild_id"])
666-
guild: Guild | None = self._get_guild(guild_id)
667-
if guild is None:
668-
_log.debug(
669-
"THREAD_CREATE referencing an unknown guild ID: %s. Discarding",
670-
guild_id,
671-
)
672-
return
673-
674-
cached_thread = guild.get_thread(int(data["id"]))
675-
if not cached_thread:
676-
thread = Thread(guild=guild, state=guild._state, data=data)
677-
guild._add_thread(thread)
678-
if data.get("newly_created"):
679-
thread._add_member(
680-
ThreadMember(
681-
thread,
682-
{
683-
"id": thread.id,
684-
"user_id": data["owner_id"],
685-
"join_timestamp": data["thread_metadata"][
686-
"create_timestamp"
687-
],
688-
"flags": utils.MISSING,
689-
},
690-
)
691-
)
692-
self.dispatch("thread_create", thread)
693-
else:
694-
self.dispatch("thread_join", cached_thread)
695527

696528
def parse_thread_update(self, data) -> None:
697529
guild_id = int(data["guild_id"])
@@ -1541,7 +1373,7 @@ async def _get_reaction_user(
15411373
self, channel: MessageableChannel, user_id: int
15421374
) -> User | Member | None:
15431375
if isinstance(channel, TextChannel):
1544-
return channel.guild.get_member(user_id)
1376+
return await channel.guild.get_member(user_id)
15451377
return await self.get_user(user_id)
15461378

15471379
async def get_reaction_emoji(self, data) -> GuildEmoji | AppEmoji | PartialEmoji:

discord/audit_logs.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -603,8 +603,8 @@ def _from_data(self, data: AuditLogEntryPayload) -> None:
603603
self.user = self._get_member(utils._get_as_snowflake(data, "user_id")) # type: ignore
604604
self._target_id = utils._get_as_snowflake(data, "target_id")
605605

606-
def _get_member(self, user_id: int) -> Member | User | None:
607-
return self.guild.get_member(user_id) or self._users.get(user_id)
606+
async def _get_member(self, user_id: int) -> Member | User | None:
607+
return await self.guild.get_member(user_id) or self._users.get(user_id)
608608

609609
def __repr__(self) -> str:
610610
return f"<AuditLogEntry id={self.id} action={self.action} user={self.user!r}>"
@@ -667,8 +667,8 @@ def _convert_target_guild(self, target_id: int) -> Guild:
667667
def _convert_target_channel(self, target_id: int) -> abc.GuildChannel | Object:
668668
return self.guild.get_channel(target_id) or Object(id=target_id)
669669

670-
def _convert_target_user(self, target_id: int) -> Member | User | None:
671-
return self._get_member(target_id)
670+
async def _convert_target_user(self, target_id: int) -> Member | User | None:
671+
return await self._get_member(target_id)
672672

673673
def _convert_target_role(self, target_id: int) -> Role | Object:
674674
return self.guild.get_role(target_id) or Object(id=target_id)
@@ -700,8 +700,8 @@ def _convert_target_invite(self, target_id: int) -> Invite:
700700
async def _convert_target_emoji(self, target_id: int) -> GuildEmoji | Object:
701701
return (await self._state.get_emoji(target_id)) or Object(id=target_id)
702702

703-
def _convert_target_message(self, target_id: int) -> Member | User | None:
704-
return self._get_member(target_id)
703+
async def _convert_target_message(self, target_id: int) -> Member | User | None:
704+
return await self._get_member(target_id)
705705

706706
def _convert_target_stage_instance(self, target_id: int) -> StageInstance | Object:
707707
return self.guild.get_stage_instance(target_id) or Object(id=target_id)

0 commit comments

Comments
 (0)