1+ from inspect import getmembers
12from typing import Dict , List , Optional , Union
23
34from lavalink import Client as LavalinkClient
4- from lavalink import DefaultPlayer
55
6- from interactions import Client , Snowflake
6+ from interactions import Client , LibraryException , Snowflake
77
88from .models import VoiceState
9+ from .player import Player
910from .websocket import VoiceWebSocketClient
1011
11- __all__ = ["VoiceClient" ]
12+ __all__ = ["VoiceClient" , "listener" ]
1213
1314
1415class VoiceClient (Client ):
1516 def __init__ (self , token : str , ** kwargs ):
1617 super ().__init__ (token , ** kwargs )
1718
1819 self ._websocket = VoiceWebSocketClient (token , self ._intents )
19- self .lavalink_client = LavalinkClient (int (self .me .id ))
20+ self .lavalink_client = LavalinkClient (int (self .me .id ), player = Player )
2021
2122 self ._websocket ._dispatch .register (
2223 self .__raw_voice_state_update , "on_raw_voice_state_update"
@@ -25,7 +26,7 @@ def __init__(self, token: str, **kwargs):
2526 self .__raw_voice_server_update , "on_raw_voice_server_update"
2627 )
2728
28- self ._websocket ._bot_var = self
29+ self ._websocket ._http . _bot_var = self
2930 self ._http ._bot_var = self
3031
3132 async def __raw_voice_state_update (self , data : dict ):
@@ -42,7 +43,7 @@ async def connect(
4243 channel_id : Union [Snowflake , int , str ],
4344 self_deaf : bool = False ,
4445 self_mute : bool = False ,
45- ) -> DefaultPlayer :
46+ ) -> Player :
4647 """
4748 Connects to voice channel and creates player.
4849
@@ -55,26 +56,35 @@ async def connect(
5556 :param self_mute: Whether bot is self muted
5657 :type self_mute: bool
5758 :return: Created guild player.
58- :rtype: DefaultPlayer
59+ :rtype: Player
5960 """
61+ # Discord will fire INVALID_SESSION if channel_id is None
62+ if guild_id is None :
63+ raise LibraryException (message = "Missed requirement argument: guild_id" )
64+ if channel_id is None :
65+ raise LibraryException (message = "Missed requirement argument: channel_id" )
66+
6067 await self ._websocket .connect_voice_channel (guild_id , channel_id , self_deaf , self_mute )
6168 player = self .lavalink_client .player_manager .get (int (guild_id ))
6269 if player is None :
6370 player = self .lavalink_client .player_manager .create (int (guild_id ))
6471 return player
6572
6673 async def disconnect (self , guild_id : Union [Snowflake , int ]):
74+ if guild_id is None :
75+ raise LibraryException (message = "Missed requirement argument: guild_id" )
76+
6777 await self ._websocket .disconnect_voice_channel (int (guild_id ))
6878 await self .lavalink_client .player_manager .destroy (int (guild_id ))
6979
70- def get_player (self , guild_id : Union [Snowflake , int ]) -> DefaultPlayer :
80+ def get_player (self , guild_id : Union [Snowflake , int ]) -> Player :
7181 """
7282 Returns current player in guild.
7383
7484 :param guild_id: The guild id
7585 :type guild_id: Union[Snowflake, int]
7686 :return: Guild player
77- :rtype: DefaultPlayer
87+ :rtype: Player
7888 """
7989 return self .lavalink_client .player_manager .get (int (guild_id ))
8090
@@ -96,7 +106,7 @@ def get_user_voice_state(self, user_id: Union[Snowflake, int]) -> Optional[Voice
96106 _user_id = Snowflake (user_id ) if isinstance (user_id , int ) else user_id
97107 return self ._http .cache [VoiceState ].get (_user_id )
98108
99- def get_guild_voice_states (self , guild_id : Union [Snowflake , int ]):
109+ def get_guild_voice_states (self , guild_id : Union [Snowflake , int ]) -> Optional [ List [ VoiceState ]] :
100110 """
101111 Returns guild voice states.
102112
@@ -131,3 +141,28 @@ def get_channel_voice_states(
131141 for voice_state in self .voice_states .values ()
132142 if voice_state .channel_id == _channel_id
133143 ]
144+
145+ def __register_lavalink_listeners (self ):
146+ for extension in self ._extensions .values ():
147+ for name , func in getmembers (extension ):
148+ if hasattr (func , "__lavalink__" ):
149+ name = func .__lavalink__ [3 :]
150+ event_name = "" .join (word .capitalize () for word in name .split ("_" )) + "Event"
151+ if event_name not in self .lavalink_client ._event_hooks :
152+ self .lavalink_client ._event_hooks [event_name ] = []
153+ self .lavalink_client ._event_hooks [event_name ].append (func )
154+
155+ async def _ready (self ) -> None :
156+ self .__register_lavalink_listeners ()
157+ await super ()._ready ()
158+
159+
160+ def listener (func = None , * , name : str = None ):
161+ def wrapper (func ):
162+ _name = name or func .__name__
163+ func .__lavalink__ = _name
164+ return func
165+
166+ if func is not None :
167+ return wrapper (func )
168+ return wrapper
0 commit comments