Skip to content

Commit 4c98a78

Browse files
committed
🚧 Better resolved data and Interaction inheritance, app commands wip
1 parent dd05817 commit 4c98a78

File tree

5 files changed

+209
-45
lines changed

5 files changed

+209
-45
lines changed

discord/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ def __init__(
295295
self._tasks = set()
296296

297297
async def _gather_events(self, event: Event) -> None:
298-
await asyncio.gather(*super()._handle_event(event))
298+
await asyncio.gather(*self._handle_event(event))
299299

300300
async def __aenter__(self) -> Client:
301301
loop = asyncio.get_running_loop()

discord/enums.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -987,6 +987,12 @@ class ApplicationCommandPermissionType(Enum):
987987
channel = 3
988988

989989

990+
class ApplicationCommandType(IntEnum):
991+
CHAT_INPUT = 1
992+
USER = 2
993+
MESSAGE = 3
994+
995+
990996
def try_enum(cls: type[E], val: Any) -> E:
991997
"""A function that tries to turn the value into enum ``cls``.
992998
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
"""
2+
The MIT License (MIT)
3+
4+
Copyright (c) 2021-present Pycord Development
5+
6+
Permission is hereby granted, free of charge, to any person obtaining a
7+
copy of this software and associated documentation files (the "Software"),
8+
to deal in the Software without restriction, including without limitation
9+
the rights to use, copy, modify, merge, publish, distribute, sublicense,
10+
and/or sell copies of the Software, and to permit persons to whom the
11+
Software is furnished to do so, subject to the following conditions:
12+
13+
The above copyright notice and this permission notice shall be included in
14+
all copies or substantial portions of the Software.
15+
16+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
17+
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
21+
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
22+
DEALINGS IN THE SOFTWARE.
23+
"""
24+
25+
from abc import ABC
26+
from functools import wraps
27+
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Coroutine, Generic, ParamSpec, Protocol, TypeAlias, TypeVar
28+
29+
from typing_extensions import Unpack
30+
31+
from ..enums import ApplicationCommandType
32+
from ..events import InteractionCreate
33+
from ..interactions import ApplicationCommandInteraction
34+
from ..utils import MISSING, Undefined
35+
from ..utils.private import hybridmethod, maybe_awaitable
36+
from .base import GearBase
37+
38+
if TYPE_CHECKING:
39+
from ..commands import ApplicationCommand
40+
P = ParamSpec("P")
41+
R = TypeVar("R")
42+
43+
44+
class CommandListener(Protocol, Generic[P, R]):
45+
__command__: "ApplicationCommand"
46+
47+
async def __call__(self, interaction: ApplicationCommandInteraction, *args: P.args, **kwargs: P.kwargs) -> R: ...
48+
49+
50+
def _listener_factory(listener: CommandListener, command_name: str) -> Callable[..., ...]:
51+
@wraps(listener)
52+
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | None:
53+
# Assume last positional arg is the interaction
54+
if args:
55+
interaction: Any = args[-1]
56+
if isinstance(interaction, ApplicationCommandInteraction) and interaction.command_name == command_name:
57+
interaction.command = listener.__command__
58+
if interaction.command_type == ApplicationCommandType.CHAT_INPUT:
59+
...
60+
elif interaction.command_type == ApplicationCommandType.USER:
61+
...
62+
63+
return await listener(*args, **kwargs)
64+
return None
65+
66+
return wrapper
67+
68+
69+
ACG_t = TypeVar("ACG_t", bound="ApplicationCommandsGearMixin")
70+
71+
72+
class ApplicationCommandsGearMixin(GearBase, ABC):
73+
"""A mixin that provides application commands handling for a :class:`discord.Gear`.
74+
75+
This mixin is used to handle application commands interactions.
76+
"""

discord/guild.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ async def _get_and_update_member(self, payload: MemberPayload, user_id: int, cac
332332
# flag should always be MemberCacheFlag.interaction) is set to True
333333
if user_id in members:
334334
member = cast(Member, await self.get_member(user_id))
335-
await member._update(payload) if cache_flag else None
335+
await member._update(payload) if cache_flag else None # TODO: This is being cached incorrectly @VincentRPS
336336
else:
337337
# NOTE:
338338
# This is a fallback in case the member is not found in the guild's members.
@@ -387,7 +387,6 @@ def __repr__(self) -> str:
387387
("id", self.id),
388388
("name", self.name),
389389
("shard_id", self.shard_id),
390-
("chunked", self.chunked),
391390
("member_count", self._member_count),
392391
)
393392
inner = " ".join("%s=%r" % t for t in attrs)

discord/interactions.py

Lines changed: 125 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,21 @@
2828
import asyncio
2929
import datetime
3030
from collections.abc import Sequence
31-
from typing import TYPE_CHECKING, Any, Coroutine, Generic, cast
31+
from typing import TYPE_CHECKING, Any, Coroutine, Generic, Protocol, cast
3232

3333
from typing_extensions import Self, TypeVar, TypeVarTuple, Unpack, override, reveal_type
3434

3535
from . import utils
3636
from .channel import PartialMessageable, _threaded_channel_factory
3737
from .components import ComponentsHolder, _partial_component_factory
38-
from .enums import ChannelType, InteractionContextType, InteractionResponseType, InteractionType, try_enum
38+
from .enums import (
39+
ApplicationCommandType,
40+
ChannelType,
41+
InteractionContextType,
42+
InteractionResponseType,
43+
InteractionType,
44+
try_enum,
45+
)
3946
from .errors import ClientException, InteractionResponded, InvalidArgument
4047
from .file import File, VoiceMessage
4148
from .flags import MessageFlags
@@ -50,6 +57,9 @@
5057
ApplicationCommandAutocompleteInteraction as ApplicationCommandAutocompleteInteractionPayload,
5158
)
5259
from .types.interactions import ApplicationCommandInteraction as ApplicationCommandInteractionPayload
60+
from .types.interactions import (
61+
ApplicationCommandInteractionDataOption,
62+
)
5363
from .types.interactions import Interaction as InteractionPayload
5464
from .user import User
5565
from .utils import find
@@ -99,6 +109,7 @@
99109
from .embeds import Embed
100110
from .mentions import AllowedMentions
101111
from .poll import Poll
112+
from .types.interactions import ComponentInteraction as ComponentInteractionPayload
102113
from .types.interactions import InteractionCallback as InteractionCallbackPayload
103114
from .types.interactions import InteractionCallbackResponse, InteractionData
104115
from .types.interactions import InteractionData as InteractionDataPayload
@@ -231,6 +242,7 @@ def __init__(self, *, payload: InteractionPayload, state: ConnectionState):
231242
self._original_response: InteractionMessage | None = None
232243
self.data = payload.get("data")
233244
self.callback: InteractionCallback | None = None
245+
super().__init__()
234246

235247
@classmethod
236248
async def _from_data(cls, payload: InteractionPayload, state: ConnectionState) -> Self:
@@ -248,7 +260,6 @@ async def _from_data(cls, payload: InteractionPayload, state: ConnectionState) -
248260
self.application_id: int = int(self._payload["application_id"])
249261
self.locale: str | None = self._payload.get("locale")
250262
self.guild_locale: str | None = self._payload.get("guild_locale")
251-
self.custom_id: str | None = self.data.get("custom_id") if self.data is not None else None
252263
self._app_permissions: int = int(self._payload.get("app_permissions", 0))
253264
self.entitlements: list[Entitlement] = [
254265
Entitlement(data=e, state=self._state) for e in self._payload.get("entitlements", [])
@@ -693,9 +704,9 @@ def to_dict(self) -> dict[str, Any]:
693704
U = TypeVar("U", bound="ApplicationCommandInteractionPayload | ApplicationCommandAutocompleteInteractionPayload")
694705

695706

696-
class _CommandBoundInteraction(Interaction[U], Generic[U]):
697-
def __init__(self, *, payload: U, state: ConnectionState):
698-
super().__init__(payload=payload, state=state)
707+
class _CommandBoundInteractionMixin:
708+
def __init__(self, *args: Any, **kwargs: Any) -> None:
709+
super().__init__(*args, **kwargs)
699710
self._command: ApplicationCommand | None = None
700711

701712
@property
@@ -707,10 +718,102 @@ def command(self) -> ApplicationCommand:
707718
return self._command
708719

709720

710-
class ApplicationCommandInteraction(_CommandBoundInteraction[ApplicationCommandInteractionPayload]): ...
721+
class _ResolvedDataInteraction(Interaction[T], Generic[T]):
722+
"""A mixin that loads and parses the resolved data from an interaction payload."""
723+
724+
__slots__: tuple[str, ...] = (
725+
"users",
726+
"members",
727+
"roles",
728+
"channels",
729+
"messages",
730+
"attachments",
731+
)
732+
733+
@override
734+
@classmethod
735+
async def _from_data(cls, payload: InteractionPayload, state: ConnectionState) -> Self:
736+
self = await super()._from_data(payload=payload, state=state)
737+
resolved = self._payload.get("data", {}).get("resolved", {})
738+
if users := resolved.get("users"):
739+
self.users: dict[int, User] = {
740+
int(user_id): User(state=state, data=user_data) for user_id, user_data in users.items()
741+
}
742+
if (members := resolved.get("members")) and (guild := await self.get_guild()):
743+
self.members: dict[int, Member] = {}
744+
for member_id, member_data in members.items():
745+
member_data["id"] = int(member_id)
746+
member_data["user"] = resolved["users"][member_id]
747+
self.members[member_data["id"]] = await guild._get_and_update_member(
748+
member_data, member_data["id"], self._state.member_cache_flags.interaction
749+
)
750+
if roles := resolved.get("roles"):
751+
self.roles: dict[int, Role] = {
752+
int(role_id): Role(state=state, data=role_data, guild=self.guild)
753+
for role_id, role_data in roles.items()
754+
}
755+
if channels := resolved.get("channels"): # noqa: F841 see below
756+
# TODO: Partial channels @Paillat-dev
757+
self.channels: dict[int, InteractionChannel] = {}
758+
if messages := resolved.get("messages"):
759+
self.messages: dict[int, Message] = {}
760+
for message_id, message_data in messages.items():
761+
channel = self.channel
762+
if channel.id != int(message_data["channel_id"]):
763+
# we got weird stuff going on, make up a channel
764+
channel = PartialMessageable(state=self._state, id=int(message_data["channel_id"]))
765+
766+
self.messages[int(message_id)] = await Message._from_data(
767+
state=self._state, channel=channel, data=message_data
768+
)
769+
if attachments := resolved.get("attachments"):
770+
self.attachments: dict[int, Attachment] = {
771+
int(att_id): Attachment(state=state, data=att_data) for att_id, att_data in attachments.items()
772+
}
773+
return self
774+
775+
776+
class ApplicationCommandInteraction(
777+
_ResolvedDataInteraction[ApplicationCommandInteractionPayload], _CommandBoundInteractionMixin
778+
):
779+
__slots__: tuple[str, ...] = ("_command",)
780+
781+
def __init__(self, *, payload: ApplicationCommandInteractionPayload, state: ConnectionState):
782+
super().__init__(payload=payload, state=state)
783+
if self.data is None: # TODO: make it so that this can never be None @Paillat-dev
784+
raise RuntimeError("This interaction has no data associated with it.")
785+
self.command_name = self._parse_command_name(self.data["name"], self.data.get("options", []))
786+
self.command_type: ApplicationCommandType = self.data["type"]
787+
self.guild_id: int | None = self.data.get("guild_id")
788+
# self.options: list[ApplicationCommandInteractionDataOption] = self.data.get("options", [])
789+
self.target: User | Member | Message | None = None
790+
self._target_id: int | None = None
791+
self._command_type: ApplicationCommandType = self.data["type"]
792+
793+
def _parse_command_name(self, current_name: str, options: list[ApplicationCommandInteractionDataOption]) -> str:
794+
if options and (child_options := options[0].get("options")):
795+
current_name += " " + options[0]["name"]
796+
return self._parse_command_name(current_name, child_options)
797+
return current_name
798+
799+
@override
800+
@classmethod
801+
async def _from_data(cls, payload: ApplicationCommandInteractionPayload, state: ConnectionState) -> Self: # ty:ignore[invalid-method-override]
802+
self: ApplicationCommandInteraction = await super()._from_data(payload=payload, state=state)
803+
if self._command_type == ApplicationCommandType.CHAT_INPUT:
804+
...
805+
else:
806+
self._target_id = int(self.data["target_id"])
807+
if self._command_type == ApplicationCommandType.USER:
808+
self.target = self.users[self._target_id]
809+
elif self._command_type == ApplicationCommandType.MESSAGE:
810+
self.target = self.messages[self._target_id]
811+
return self
711812

712813

713-
class AutocompleteInteraction(_CommandBoundInteraction[ApplicationCommandAutocompleteInteractionPayload]):
814+
class AutocompleteInteraction(
815+
Interaction[ApplicationCommandAutocompleteInteractionPayload], _CommandBoundInteractionMixin
816+
):
714817
def __init__(self, *, payload: ApplicationCommandAutocompleteInteractionPayload, state: ConnectionState):
715818
super().__init__(payload=payload, state=state)
716819
options = self.data.get("options", [])
@@ -725,50 +828,30 @@ def __init__(self, *, payload: ApplicationCommandAutocompleteInteractionPayload,
725828
Components_t = TypeVarTuple("Components_t", default="Unpack[tuple[AnyTopLevelModalPartialComponent, ...]]")
726829

727830

728-
class ModalInteraction(Interaction, Generic[Unpack[Components_t]]):
729-
__slots__ = ("_components", "users", "attachments", "roles")
831+
class ModalInteraction(_ResolvedDataInteraction["ModalInteractionPayload"], Generic[Unpack[Components_t]]):
832+
__slots__ = ("components", "custom_id")
730833

834+
@override
731835
def __init__(self, *, payload: ModalInteractionPayload, state: ConnectionState):
732836
super().__init__(payload=payload, state=state)
733-
resolved = payload.get("data", {}).get("resolved", {})
734-
self.users: dict[int, User] = {
735-
int(user_id): User(state=state, data=user_data) for user_id, user_data in resolved.get("users", {}).items()
736-
}
737-
self.attachments: dict[int, Attachment] = {
738-
int(att_id): Attachment(state=state, data=att_data)
739-
for att_id, att_data in resolved.get("attachments", {}).items()
740-
}
741-
self.roles: dict[int, Role] = {
742-
int(role_id): Role(state=state, data=role_data, guild=self.guild)
743-
for role_id, role_data in resolved.get("roles", {}).items()
744-
}
745-
746-
# TODO: When we have better partial objects, add self.channels and self.members
747-
748-
@cached_slot_property("_components")
749-
def components(self) -> ComponentsHolder[Unpack[Components_t]]:
750-
if not self.type == InteractionType.modal_submit:
751-
raise TypeError("Only modal submit interactions have components")
752-
if not self.data:
753-
raise TypeError("This interaction has no data. This should never happen, please open an issue on GitHub")
837+
self.custom_id: str = self.data["custom_id"]
754838
components_payload = cast("list[PartialComponent]", self.data.get("components", []))
755-
756-
return ComponentsHolder(*(_partial_component_factory(component) for component in components_payload)) # pyright: ignore[reportReturnType]
839+
self.components: ComponentsHolder[Unpack[Components_t]] = ComponentsHolder(
840+
*(_partial_component_factory(component) for component in components_payload)
841+
)
757842

758843

759844
Component_t = TypeVar("Component_t", bound="AnyMessagePartialComponent", default="AnyMessagePartialComponent")
760845

761846

762-
class ComponentInteraction(Interaction, Generic[Component_t]):
763-
__slots__ = ("_component",)
847+
class ComponentInteraction(_ResolvedDataInteraction["ComponentInteractionPayload"], Generic[Component_t]):
848+
__slots__ = ("component", "custom_id")
764849

765-
@cached_slot_property("_component")
766-
def component(self) -> Component_t:
767-
if not self.type == InteractionType.component:
768-
raise TypeError("Only component interactions have a component")
769-
if not self.data:
770-
raise TypeError("This interaction has no data. This should never happen, please open an issue on GitHub")
771-
return _partial_component_factory(self.data, key="component_type") # pyright: ignore[reportArgumentType, reportReturnType]
850+
@override
851+
def __init__(self, *, payload: ComponentInteractionPayload, state: ConnectionState):
852+
super().__init__(payload=payload, state=state)
853+
self.custom_id: str = self.data["custom_id"]
854+
self.component: Component_t = _partial_component_factory(self.data, key="component_type")
772855

773856

774857
class InteractionResponse:

0 commit comments

Comments
 (0)