Skip to content

Commit 6335944

Browse files
Add channel_types, should fix #203
1 parent 699849d commit 6335944

File tree

2 files changed

+58
-12
lines changed

2 files changed

+58
-12
lines changed

discord/app/commands.py

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,13 @@
2525
from __future__ import annotations
2626

2727
import asyncio
28+
from discord.types.channel import TextChannel, VoiceChannel
2829
import functools
2930
import inspect
3031
from collections import OrderedDict
3132
from typing import Any, Callable, Dict, List, Optional, Union
3233

33-
from ..enums import SlashCommandOptionType
34+
from ..enums import SlashCommandOptionType, ChannelType
3435
from ..member import Member
3536
from ..user import User
3637
from ..message import Message
@@ -358,10 +359,15 @@ def parse_options(self, params) -> List[Option]:
358359
if option == inspect.Parameter.empty:
359360
option = str
360361

361-
if self._is_typing_optional(option):
362-
option = Option(
363-
option.__args__[0], "No description provided", required=False
364-
)
362+
if self._is_typing_union(option):
363+
if self._is_typing_optional(option):
364+
option = Option(
365+
option.__args__[0], "No description provided", required=False
366+
)
367+
else:
368+
option = Option(
369+
option.__args__, "No description provided"
370+
)
365371

366372
if not isinstance(option, Option):
367373
option = Option(option, "No description provided")
@@ -380,8 +386,11 @@ def parse_options(self, params) -> List[Option]:
380386

381387
return final_options
382388

389+
def _is_typing_union(self, annotation):
390+
return getattr(annotation, "__origin__", None) is Union # type: ignore
391+
383392
def _is_typing_optional(self, annotation):
384-
return getattr(annotation, "__origin__", None) is Union and type(None) in annotation.__args__ # type: ignore
393+
return self._is_typing_union(annotation) and type(None) in annotation.__args__ # type: ignore
385394

386395
def to_dict(self) -> Dict:
387396
as_dict = {
@@ -474,15 +483,33 @@ def _update_copy(self, kwargs: Dict[str, Any]):
474483
else:
475484
return self.copy()
476485

486+
channel_type_map = {
487+
'TextChannel': ChannelType.text,
488+
'VoiceChannel': ChannelType.voice,
489+
'StageChannel': ChannelType.stage_voice,
490+
'CategoryChannel': ChannelType.category
491+
}
492+
477493
class Option:
478494
def __init__(
479495
self, input_type: Any, /, description: str = None, **kwargs
480496
) -> None:
481497
self.name: Optional[str] = kwargs.pop("name", None)
482498
self.description = description or "No description provided"
499+
500+
self.channel_types = []
483501
if not isinstance(input_type, SlashCommandOptionType):
484-
input_type = SlashCommandOptionType.from_datatype(input_type)
485-
self.input_type = input_type
502+
self.input_type = SlashCommandOptionType.from_datatype(input_type)
503+
if self.input_type == SlashCommandOptionType.channel:
504+
input_type = (input_type,) if not isinstance(input_type, tuple) else input_type
505+
for i in input_type:
506+
if i.__name__ == 'GuildChannel':
507+
continue
508+
channel_type = channel_type_map[i.__name__].value
509+
self.channel_types.append(channel_type)
510+
else:
511+
self.input_type = input_type
512+
print(self.channel_types)
486513
self.required: bool = kwargs.pop("required", True)
487514
self.choices: List[OptionChoice] = [
488515
o if isinstance(o, OptionChoice) else OptionChoice(o)
@@ -491,13 +518,18 @@ def __init__(
491518
self.default = kwargs.pop("default", None)
492519

493520
def to_dict(self) -> Dict:
494-
return {
521+
as_dict = {
495522
"name": self.name,
496523
"description": self.description,
497524
"type": self.input_type.value,
498525
"required": self.required,
499526
"choices": [c.to_dict() for c in self.choices],
500527
}
528+
if self.channel_types:
529+
as_dict["channel_types"] = self.channel_types
530+
531+
return as_dict
532+
501533

502534
def __repr__(self):
503535
return f"<discord.app.commands.{self.__class__.__name__} name={self.name}>"

discord/enums.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -603,6 +603,16 @@ class SlashCommandOptionType(Enum):
603603

604604
@classmethod
605605
def from_datatype(cls, datatype):
606+
607+
if isinstance(datatype, tuple): # typing.Union has been used
608+
datatypes = [cls.from_datatype(op) for op in datatype]
609+
if all([x == cls.channel for x in datatypes]):
610+
return cls.channel
611+
elif set(datatypes) <= {cls.role, cls.channel}:
612+
return cls.mentionable
613+
else:
614+
raise TypeError('Invalid usage of typing.Union')
615+
606616
if issubclass(datatype, str):
607617
return cls.string
608618
if issubclass(datatype, bool):
@@ -614,15 +624,19 @@ def from_datatype(cls, datatype):
614624

615625
if datatype.__name__ == "Member":
616626
return cls.user
617-
if datatype.__name__ == "GuildChannel":
627+
if datatype.__name__ in [
628+
"GuildChannel", "TextChannel",
629+
"VoiceChannel", "StageChannel",
630+
"CategoryChannel"
631+
]:
618632
return cls.channel
619633
if datatype.__name__ == "Role":
620634
return cls.role
621635
if datatype.__name__ == "Mentionable":
622636
return cls.mentionable
623-
637+
624638
# TODO: Improve the error message
625-
raise TypeError('Invalid class used as an input type for an Option')
639+
raise TypeError(f'Invalid class {datatype} used as an input type for an Option')
626640

627641
T = TypeVar('T')
628642

0 commit comments

Comments
 (0)