Skip to content

Commit 97dac81

Browse files
committed
refactor!: another bulk of changes
I need to make more commits
1 parent 0d0ac28 commit 97dac81

File tree

22 files changed

+601
-259
lines changed

22 files changed

+601
-259
lines changed

discord/app/cache.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -211,15 +211,15 @@ async def get_user(self, user_id: int) -> User:
211211
async def get_all_stickers(self) -> list[GuildSticker]:
212212
return list(self._stickers.values())
213213

214-
async def get_sticker(self, sticker_id: int) -> GuildSticker:
214+
async def get_sticker(self, sticker_id: int) -> GuildSticker | None:
215215
return self._stickers.get(sticker_id)
216216

217217
async def store_sticker(self, guild: Guild, data: GuildStickerPayload) -> GuildSticker:
218218
sticker = GuildSticker(state=self._state, data=data)
219219
try:
220220
self._stickers[guild.id].append(sticker)
221221
except KeyError:
222-
self._stickers[guild.id] = sticker
222+
self._stickers[guild.id] = [sticker]
223223
return sticker
224224

225225
async def delete_sticker(self, sticker_id: int) -> None:
@@ -228,10 +228,12 @@ async def delete_sticker(self, sticker_id: int) -> None:
228228
# interactions
229229

230230
async def delete_view_on(self, message_id: int) -> View | None:
231-
return self._views.pop(message_id, None)
231+
for view in await self.get_all_views():
232+
if view.message and view.message.id == message_id:
233+
return view
232234

233235
async def store_view(self, view: View, message_id: int) -> None:
234-
self._views[message_id or view.id] = view
236+
self._views[str(message_id or view.id)] = view
235237

236238
async def get_all_views(self) -> list[View]:
237239
return list(self._views.values())

discord/app/events.py renamed to discord/app/event_emitter.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,14 @@
2929

3030
from .state import ConnectionState
3131

32-
T = TypeVar('T')
32+
T = TypeVar('T', bound='Event')
3333

3434

3535
class Event(ABC):
3636
__event_name__: str
3737

3838
@classmethod
39-
async def __load__(cls, data: dict[str, Any], state: ConnectionState) -> Self:
39+
async def __load__(cls, data: Any, state: ConnectionState) -> Self | None:
4040
...
4141

4242

@@ -61,7 +61,7 @@ def add_listener(self, event: Type[Event], listener: Callable) -> None:
6161
self._listeners[event].append(listener)
6262
except KeyError:
6363
self.add_event(event)
64-
self._listener[event] = [listener]
64+
self._listeners[event] = [listener]
6565

6666
def remove_listener(self, event: Type[Event], listener: Callable) -> None:
6767
self._listeners[event].remove(listener)
@@ -79,21 +79,23 @@ def add_wait_for(self, event: Type[T]) -> Future[T]:
7979
def remove_wait_for(self, event: Type[Event], fut: Future) -> None:
8080
self._wait_fors[event].remove(fut)
8181

82-
async def publish(self, event_str: str, data: dict[str, Any]) -> None:
83-
items = list(self.events.items())
82+
async def emit(self, event_str: str, data: Any) -> None:
83+
events = self._events.get(event_str, [])
8484

85-
for event, funcs in items:
86-
if event._name == event_str:
87-
eve = event()
85+
for event in events:
86+
eve = await event.__load__(data=data, state=self._state)
8887

89-
await eve.__load__(data)
88+
if eve is None:
89+
continue
9090

91-
for func in funcs:
92-
asyncio.create_task(func(eve))
91+
funcs = self._listeners.get(event, [])
9392

94-
wait_fors = self.wait_fors.get(event)
93+
for func in funcs:
94+
asyncio.create_task(func(eve))
9595

96-
if wait_fors is not None:
97-
for wait_for in wait_fors:
98-
wait_for.set_result(eve)
99-
self.wait_fors.pop(event)
96+
wait_fors = self._wait_fors.get(event)
97+
98+
if wait_fors is not None:
99+
for wait_for in wait_fors:
100+
wait_for.set_result(eve)
101+
self._wait_fors.pop(event)

discord/app/state.py

Lines changed: 6 additions & 175 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@
4343
Union,
4444
)
4545

46+
from discord.app.event_emitter import EventEmitter
47+
4648
from .cache import Cache
4749

4850
from .. import utils
@@ -186,6 +188,7 @@ def __init__(
186188
self.application_id: int | None = utils._get_as_snowflake(
187189
options, "application_id"
188190
)
191+
self.application_flags: ApplicationFlags | None = None
189192
self.heartbeat_timeout: float = options.get("heartbeat_timeout", 60.0)
190193
self.guild_ready_timeout: float = options.get("guild_ready_timeout", 2.0)
191194
if self.guild_ready_timeout < 0:
@@ -256,10 +259,7 @@ def __init__(
256259

257260
self.cache_app_emojis: bool = options.get("cache_app_emojis", False)
258261

259-
self.parsers = parsers = {}
260-
for attr, func in inspect.getmembers(self):
261-
if attr.startswith("parse_"):
262-
parsers[attr[6:].upper()] = func
262+
self.emitter = EventEmitter(self)
263263

264264
self.cache: Cache = self.cache
265265

@@ -268,13 +268,13 @@ def clear(self, *, views: bool = True) -> None:
268268
self.cache.clear()
269269
self._voice_clients: dict[int, VoiceClient] = {}
270270

271-
def process_chunk_requests(
271+
async def process_chunk_requests(
272272
self, guild_id: int, nonce: str | None, members: list[Member], complete: bool
273273
) -> None:
274274
removed = []
275275
for key, request in self._chunk_requests.items():
276276
if request.guild_id == guild_id and request.nonce == nonce:
277-
request.add_members(members)
277+
await request.add_members(members)
278278
if complete:
279279
request.done()
280280
removed.append(key)
@@ -524,175 +524,6 @@ async def query_members(
524524
)
525525
raise
526526

527-
async def _delay_ready(self) -> None:
528-
529-
if self.cache_app_emojis and self.application_id:
530-
data = await self.http.get_all_application_emojis(self.application_id)
531-
for e in data.get("items", []):
532-
await self.maybe_store_app_emoji(self.application_id, e)
533-
try:
534-
states = []
535-
while True:
536-
# this snippet of code is basically waiting N seconds
537-
# until the last GUILD_CREATE was sent
538-
try:
539-
guild = await asyncio.wait_for(
540-
self._ready_state.get(), timeout=self.guild_ready_timeout
541-
)
542-
except asyncio.TimeoutError:
543-
break
544-
else:
545-
if self._guild_needs_chunking(guild):
546-
future = await self.chunk_guild(guild, wait=False)
547-
states.append((guild, future))
548-
elif guild.unavailable is False:
549-
self.dispatch("guild_available", guild)
550-
else:
551-
self.dispatch("guild_join", guild)
552-
553-
for guild, future in states:
554-
try:
555-
await asyncio.wait_for(future, timeout=5.0)
556-
except asyncio.TimeoutError:
557-
_log.warning(
558-
"Shard ID %s timed out waiting for chunks for guild_id %s.",
559-
guild.shard_id,
560-
guild.id,
561-
)
562-
563-
if guild.unavailable is False:
564-
self.dispatch("guild_available", guild)
565-
else:
566-
self.dispatch("guild_join", guild)
567-
568-
# remove the state
569-
try:
570-
del self._ready_state
571-
except AttributeError:
572-
pass # already been deleted somehow
573-
574-
except asyncio.CancelledError:
575-
pass
576-
else:
577-
# dispatch the event
578-
self.call_handlers("ready")
579-
self.dispatch("ready")
580-
finally:
581-
self._ready_task = None
582-
583-
def parse_ready(self, data) -> None:
584-
if self._ready_task is not None:
585-
self._ready_task.cancel()
586-
587-
self._ready_state = asyncio.Queue()
588-
self.clear(views=False)
589-
self.user = ClientUser(state=self, data=data["user"])
590-
self.store_user(data["user"])
591-
592-
if self.application_id is None:
593-
try:
594-
application = data["application"]
595-
except KeyError:
596-
pass
597-
else:
598-
self.application_id = utils._get_as_snowflake(application, "id")
599-
# flags will always be present here
600-
self.application_flags = ApplicationFlags._from_value(application["flags"]) # type: ignore
601-
602-
for guild_data in data["guilds"]:
603-
self._add_guild_from_data(guild_data)
604-
605-
self.dispatch("connect")
606-
self._ready_task = asyncio.create_task(self._delay_ready())
607-
608-
def parse_resumed(self, data) -> None:
609-
self.dispatch("resumed")
610-
611-
def parse_application_command_permissions_update(self, data) -> None:
612-
# unsure what the implementation would be like
613-
pass
614-
615-
def parse_auto_moderation_rule_create(self, data) -> None:
616-
rule = AutoModRule(state=self, data=data)
617-
self.dispatch("auto_moderation_rule_create", rule)
618-
619-
def parse_auto_moderation_rule_update(self, data) -> None:
620-
# somehow get a 'before' object?
621-
rule = AutoModRule(state=self, data=data)
622-
self.dispatch("auto_moderation_rule_update", rule)
623-
624-
def parse_auto_moderation_rule_delete(self, data) -> None:
625-
rule = AutoModRule(state=self, data=data)
626-
self.dispatch("auto_moderation_rule_delete", rule)
627-
628-
def parse_auto_moderation_action_execution(self, data) -> None:
629-
event = AutoModActionExecutionEvent(self, data)
630-
self.dispatch("auto_moderation_action_execution", event)
631-
632-
def parse_entitlement_create(self, data) -> None:
633-
event = Entitlement(data=data, state=self)
634-
self.dispatch("entitlement_create", event)
635-
636-
def parse_entitlement_update(self, data) -> None:
637-
event = Entitlement(data=data, state=self)
638-
self.dispatch("entitlement_update", event)
639-
640-
def parse_entitlement_delete(self, data) -> None:
641-
event = Entitlement(data=data, state=self)
642-
self.dispatch("entitlement_delete", event)
643-
644-
def parse_subscription_create(self, data) -> None:
645-
event = Subscription(data=data, state=self)
646-
self.dispatch("subscription_create", event)
647-
648-
def parse_subscription_update(self, data) -> None:
649-
event = Subscription(data=data, state=self)
650-
self.dispatch("subscription_update", event)
651-
652-
def parse_subscription_delete(self, data) -> None:
653-
event = Subscription(data=data, state=self)
654-
self.dispatch("subscription_delete", event)
655-
656-
def parse_message_create(self, data) -> None:
657-
channel, _ = self._get_guild_channel(data)
658-
# channel would be the correct type here
659-
message = Message(channel=channel, data=data, state=self) # type: ignore
660-
self.dispatch("message", message)
661-
if self._messages is not None:
662-
self._messages.append(message)
663-
# we ensure that the channel is either a TextChannel, VoiceChannel, StageChannel, or Thread
664-
if channel and channel.__class__ in (
665-
TextChannel,
666-
VoiceChannel,
667-
StageChannel,
668-
Thread,
669-
):
670-
channel.last_message_id = message.id # type: ignore
671-
672-
def parse_message_delete(self, data) -> None:
673-
raw = RawMessageDeleteEvent(data)
674-
found = self._get_message(raw.message_id)
675-
raw.cached_message = found
676-
self.dispatch("raw_message_delete", raw)
677-
if self._messages is not None and found is not None:
678-
self.dispatch("message_delete", found)
679-
self._messages.remove(found)
680-
681-
def parse_message_delete_bulk(self, data) -> None:
682-
raw = RawBulkMessageDeleteEvent(data)
683-
if self._messages:
684-
found_messages = [
685-
message for message in self._messages if message.id in raw.message_ids
686-
]
687-
else:
688-
found_messages = []
689-
raw.cached_messages = found_messages
690-
self.dispatch("raw_bulk_message_delete", raw)
691-
if found_messages:
692-
self.dispatch("bulk_message_delete", found_messages)
693-
for msg in found_messages:
694-
# self._messages won't be None here
695-
self._messages.remove(msg) # type: ignore
696527

697528
def parse_message_update(self, data) -> None:
698529
raw = RawMessageUpdateEvent(data)

discord/client.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -879,10 +879,9 @@ def intents(self) -> Intents:
879879

880880
# helpers/getters
881881

882-
@property
883-
def users(self) -> list[User]:
882+
async def get_users(self) -> list[User]:
884883
"""Returns a list of all the users the bot can see."""
885-
return list(self._connection._users.values())
884+
return await self._connection.cache.get_all_users()
886885

887886
async def fetch_application(self, application_id: int, /) -> PartialAppInfo:
888887
"""|coro|

discord/commands/core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -991,7 +991,7 @@ async def _invoke(self, ctx: ApplicationContext) -> None:
991991
# We resolved the user from the user id
992992
_data["user"] = _user_data
993993
cache_flag = ctx.interaction._state.member_cache_flags.interaction
994-
arg = ctx.guild._get_and_update_member(_data, int(arg), cache_flag)
994+
arg = await ctx.guild._get_and_update_member(_data, int(arg), cache_flag)
995995
elif op.input_type is SlashCommandOptionType.mentionable:
996996
if (_data := resolved.get("users", {}).get(arg)) is not None:
997997
arg = User(state=ctx.interaction._state, data=_data)
@@ -1787,7 +1787,7 @@ async def _invoke(self, ctx: ApplicationContext) -> None:
17871787
user = v
17881788
member["user"] = user
17891789
cache_flag = ctx.interaction._state.member_cache_flags.interaction
1790-
target = ctx.guild._get_and_update_member(member, user["id"], cache_flag)
1790+
target = await ctx.guild._get_and_update_member(member, user["id"], cache_flag)
17911791
if self.cog is not None:
17921792
await self.callback(self.cog, ctx, target)
17931793
else:

discord/enums.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
"EntitlementOwnerType",
7777
"IntegrationType",
7878
"InteractionContextType",
79+
"ApplicationCommandPermissionType"
7980
)
8081

8182

@@ -1063,6 +1064,14 @@ class SubscriptionStatus(Enum):
10631064
inactive = 2
10641065

10651066

1067+
class ApplicationCommandPermissionType(Enum):
1068+
"""The type of permission"""
1069+
1070+
role = 1
1071+
user = 2
1072+
channel = 3
1073+
1074+
10661075
T = TypeVar("T")
10671076

10681077

0 commit comments

Comments
 (0)