25
25
from __future__ import annotations
26
26
27
27
import asyncio
28
+ from discord .types .channel import TextChannel , VoiceChannel
28
29
import functools
29
30
import inspect
30
31
from collections import OrderedDict
31
32
from typing import Any , Callable , Dict , List , Optional , Union
32
33
33
- from ..enums import SlashCommandOptionType
34
+ from ..enums import SlashCommandOptionType , ChannelType
34
35
from ..member import Member
35
36
from ..user import User
36
37
from ..message import Message
@@ -358,10 +359,15 @@ def parse_options(self, params) -> List[Option]:
358
359
if option == inspect .Parameter .empty :
359
360
option = str
360
361
361
- if self ._is_typing_optional (option ):
362
- option = Option (
363
- option .__args__ [0 ], "No description provided" , required = False
364
- )
362
+ if self ._is_typing_union (option ):
363
+ if self ._is_typing_optional (option ):
364
+ option = Option (
365
+ option .__args__ [0 ], "No description provided" , required = False
366
+ )
367
+ else :
368
+ option = Option (
369
+ option .__args__ , "No description provided"
370
+ )
365
371
366
372
if not isinstance (option , Option ):
367
373
option = Option (option , "No description provided" )
@@ -380,8 +386,11 @@ def parse_options(self, params) -> List[Option]:
380
386
381
387
return final_options
382
388
389
+ def _is_typing_union (self , annotation ):
390
+ return getattr (annotation , "__origin__" , None ) is Union # type: ignore
391
+
383
392
def _is_typing_optional (self , annotation ):
384
- return getattr (annotation , "__origin__" , None ) is Union and type (None ) in annotation .__args__ # type: ignore
393
+ return self . _is_typing_union (annotation ) and type (None ) in annotation .__args__ # type: ignore
385
394
386
395
def to_dict (self ) -> Dict :
387
396
as_dict = {
@@ -474,15 +483,33 @@ def _update_copy(self, kwargs: Dict[str, Any]):
474
483
else :
475
484
return self .copy ()
476
485
486
+ channel_type_map = {
487
+ 'TextChannel' : ChannelType .text ,
488
+ 'VoiceChannel' : ChannelType .voice ,
489
+ 'StageChannel' : ChannelType .stage_voice ,
490
+ 'CategoryChannel' : ChannelType .category
491
+ }
492
+
477
493
class Option :
478
494
def __init__ (
479
495
self , input_type : Any , / , description : str = None , ** kwargs
480
496
) -> None :
481
497
self .name : Optional [str ] = kwargs .pop ("name" , None )
482
498
self .description = description or "No description provided"
499
+
500
+ self .channel_types = []
483
501
if not isinstance (input_type , SlashCommandOptionType ):
484
- input_type = SlashCommandOptionType .from_datatype (input_type )
485
- self .input_type = input_type
502
+ self .input_type = SlashCommandOptionType .from_datatype (input_type )
503
+ if self .input_type == SlashCommandOptionType .channel :
504
+ input_type = (input_type ,) if not isinstance (input_type , tuple ) else input_type
505
+ for i in input_type :
506
+ if i .__name__ == 'GuildChannel' :
507
+ continue
508
+ channel_type = channel_type_map [i .__name__ ].value
509
+ self .channel_types .append (channel_type )
510
+ else :
511
+ self .input_type = input_type
512
+ print (self .channel_types )
486
513
self .required : bool = kwargs .pop ("required" , True )
487
514
self .choices : List [OptionChoice ] = [
488
515
o if isinstance (o , OptionChoice ) else OptionChoice (o )
@@ -491,13 +518,18 @@ def __init__(
491
518
self .default = kwargs .pop ("default" , None )
492
519
493
520
def to_dict (self ) -> Dict :
494
- return {
521
+ as_dict = {
495
522
"name" : self .name ,
496
523
"description" : self .description ,
497
524
"type" : self .input_type .value ,
498
525
"required" : self .required ,
499
526
"choices" : [c .to_dict () for c in self .choices ],
500
527
}
528
+ if self .channel_types :
529
+ as_dict ["channel_types" ] = self .channel_types
530
+
531
+ return as_dict
532
+
501
533
502
534
def __repr__ (self ):
503
535
return f"<discord.app.commands.{ self .__class__ .__name__ } name={ self .name } >"
0 commit comments