2626from __future__ import annotations
2727
2828import asyncio
29+ from collections .abc import Callable
2930import concurrent .futures
3031import logging
3132import struct
3435import time
3536import traceback
3637import zlib
37- from collections import deque , namedtuple
38- from typing import TYPE_CHECKING
38+ from collections import deque
39+ from typing import TYPE_CHECKING , Any , NamedTuple
3940
4041import aiohttp
4142
4445from .enums import SpeakingState
4546from .errors import ConnectionClosed , InvalidArgument
4647
48+ if TYPE_CHECKING :
49+ from typing_extensions import Self
50+
51+ from .client import Client
52+ from .state import ConnectionState
53+
4754_log = logging .getLogger (__name__ )
4855
4956__all__ = (
@@ -68,26 +75,30 @@ class WebSocketClosure(Exception):
6875 """An exception to make up for the fact that aiohttp doesn't signal closure."""
6976
7077
71- EventListener = namedtuple ("EventListener" , "predicate event result future" )
78+ class EventListener (NamedTuple ):
79+ predicate : Callable [[dict [str , Any ]], bool ]
80+ event : str
81+ result : Callable [[dict [str , Any ]], Any ] | None
82+ future : asyncio .Future [Any ]
7283
7384
7485class GatewayRatelimiter :
75- def __init__ (self , count = 110 , per = 60.0 ):
86+ def __init__ (self , count : int = 110 , per : float = 60.0 ):
7687 # The default is 110 to give room for at least 10 heartbeats per minute
77- self .max = count
78- self .remaining = count
79- self .window = 0.0
80- self .per = per
81- self .lock = asyncio .Lock ()
82- self .shard_id = None
83-
84- def is_ratelimited (self ):
88+ self .max : int = count
89+ self .remaining : int = count
90+ self .window : float = 0.0
91+ self .per : float = per
92+ self .lock : asyncio . Lock = asyncio .Lock ()
93+ self .shard_id : int | None = None
94+
95+ def is_ratelimited (self ) -> bool :
8596 current = time .time ()
8697 if current > self .window + self .per :
8798 return False
8899 return self .remaining == 0
89100
90- def get_delay (self ):
101+ def get_delay (self ) -> float :
91102 current = time .time ()
92103
93104 if current > self .window + self .per :
@@ -105,7 +116,7 @@ def get_delay(self):
105116
106117 return 0.0
107118
108- async def block (self ):
119+ async def block (self ) -> None :
109120 async with self .lock :
110121 delta = self .get_delay ()
111122 if delta :
@@ -118,12 +129,16 @@ async def block(self):
118129
119130
120131class KeepAliveHandler (threading .Thread ):
121- def __init__ (self , * args , ** kwargs ):
122- ws = kwargs .pop ("ws" , None )
123- interval = kwargs .pop ("interval" , None )
124- shard_id = kwargs .pop ("shard_id" , None )
132+ def __init__ (
133+ self ,
134+ * args : Any ,
135+ ws : DiscordWebSocket ,
136+ shard_id : int | None = None ,
137+ interval : float | None = None ,
138+ ** kwargs : Any ,
139+ ) -> None :
125140 threading .Thread .__init__ (self , * args , ** kwargs )
126- self .ws = ws
141+ self .ws : DiscordWebSocket = ws
127142 self ._main_thread_id = ws .thread_id
128143 self .interval = interval
129144 self .daemon = True
@@ -292,52 +307,63 @@ class DiscordWebSocket:
292307 HEARTBEAT_ACK = 11
293308 GUILD_SYNC = 12
294309
295- def __init__ (self , socket , * , loop ):
296- self .socket = socket
297- self .loop = loop
310+ if TYPE_CHECKING :
311+ token : str | None
312+ _connection : ConnectionState
313+ _discord_parsers : dict [str , Callable [..., Any ]]
314+ call_hooks : Callable [..., Any ]
315+ gateway : str
316+ _initial_identify : bool
317+ shard_id : int | None
318+ shard_count : int | None
319+ _max_heartbeat_timeout : float
320+
321+ def __init__ (self , socket : aiohttp .ClientWebSocketResponse , * , loop : asyncio .AbstractEventLoop ) -> None :
322+ self .socket : aiohttp .ClientWebSocketResponse = socket
323+ self .loop : asyncio .AbstractEventLoop = loop
298324
299325 # an empty dispatcher to prevent crashes
300- self ._dispatch = lambda * args : None
326+ self ._dispatch : Callable [..., Any ] = lambda * args : None
301327 # generic event listeners
302- self ._dispatch_listeners = []
328+ self ._dispatch_listeners : list [ EventListener ] = []
303329 # the keep alive
304- self ._keep_alive = None
305- self .thread_id = threading .get_ident ()
330+ self ._keep_alive : KeepAliveHandler | None = None
331+ self .thread_id : int = threading .get_ident ()
306332
307333 # ws related stuff
308- self .session_id = None
309- self .sequence = None
310- self .resume_gateway_url = None
311- self ._zlib = zlib .decompressobj ()
312- self ._buffer = bytearray ()
313- self ._close_code = None
314- self ._rate_limiter = GatewayRatelimiter ()
334+ self .session_id : str | None = None
335+ self .sequence : int | None = None
336+ self .resume_gateway_url : str | None = None
337+ self ._zlib : zlib . _Decompress = zlib .decompressobj ()
338+ self ._buffer : bytearray = bytearray ()
339+ self ._close_code : int | None = None
340+ self ._rate_limiter : GatewayRatelimiter = GatewayRatelimiter ()
315341
316342 @property
317- def open (self ):
343+ def open (self ) -> bool :
318344 return not self .socket .closed
319345
320- def is_ratelimited (self ):
346+ def is_ratelimited (self ) -> bool :
321347 return self ._rate_limiter .is_ratelimited ()
322348
323- def debug_log_receive (self , data , / ) :
349+ def debug_log_receive (self , data : dict [ str , Any ], / ) -> None :
324350 self ._dispatch ("socket_raw_receive" , data )
325351
326- def log_receive (self , _ , / ) :
352+ def log_receive (self , _ : dict [ str , Any ], / ) -> None :
327353 pass
328354
329355 @classmethod
330356 async def from_client (
331357 cls ,
332- client ,
358+ client : Client ,
333359 * ,
334- initial = False ,
335- gateway = None ,
336- shard_id = None ,
337- session = None ,
338- sequence = None ,
339- resume = False ,
340- ):
360+ initial : bool = False ,
361+ gateway : str | None = None ,
362+ shard_id : int | None = None ,
363+ session : str | None = None ,
364+ sequence : int | None = None ,
365+ resume : bool = False ,
366+ ) -> Self :
341367 """Creates a main websocket for Discord from a :class:`Client`.
342368
343369 This is for internal use only.
@@ -379,7 +405,12 @@ async def from_client(
379405 await ws .resume ()
380406 return ws
381407
382- def wait_for (self , event , predicate , result = None ):
408+ def wait_for (
409+ self ,
410+ event : str ,
411+ predicate : Callable [[dict [str , Any ]], bool ],
412+ result : Callable [[dict [str , Any ]], Any ] | None = None ,
413+ ) -> asyncio .Future [Any ]:
383414 """Waits for a DISPATCH'd event that meets the predicate.
384415
385416 Parameters
@@ -406,7 +437,7 @@ def wait_for(self, event, predicate, result=None):
406437 self ._dispatch_listeners .append (entry )
407438 return future
408439
409- async def identify (self ):
440+ async def identify (self ) -> None :
410441 """Sends the IDENTIFY packet."""
411442 payload = {
412443 "op" : self .IDENTIFY ,
@@ -419,7 +450,6 @@ async def identify(self):
419450 },
420451 "compress" : True ,
421452 "large_threshold" : 250 ,
422- "v" : 3 ,
423453 },
424454 }
425455
@@ -444,7 +474,7 @@ async def identify(self):
444474 await self .send_as_json (payload )
445475 _log .info ("Shard ID %s has sent the IDENTIFY payload." , self .shard_id )
446476
447- async def resume (self ):
477+ async def resume (self ) -> None :
448478 """Sends the RESUME packet."""
449479 payload = {
450480 "op" : self .RESUME ,
@@ -458,7 +488,7 @@ async def resume(self):
458488 await self .send_as_json (payload )
459489 _log .info ("Shard ID %s has sent the RESUME payload." , self .shard_id )
460490
461- async def received_message (self , msg , / ):
491+ async def received_message (self , msg : Any , / ):
462492 if type (msg ) is bytes :
463493 self ._buffer .extend (msg )
464494
@@ -594,7 +624,7 @@ def latency(self) -> float:
594624 heartbeat = self ._keep_alive
595625 return float ("inf" ) if heartbeat is None else heartbeat .latency
596626
597- def _can_handle_close (self ):
627+ def _can_handle_close (self ) -> bool :
598628 code = self ._close_code or self .socket .close_code
599629 is_improper_close = self ._close_code is None and self .socket .close_code == 1000
600630 return is_improper_close or code not in (
@@ -607,7 +637,7 @@ def _can_handle_close(self):
607637 4014 ,
608638 )
609639
610- async def poll_event (self ):
640+ async def poll_event (self ) -> None :
611641 """Polls for a DISPATCH event and handles the general gateway loop.
612642
613643 Raises
@@ -621,11 +651,12 @@ async def poll_event(self):
621651 await self .received_message (msg .data )
622652 elif msg .type is aiohttp .WSMsgType .BINARY :
623653 await self .received_message (msg .data )
654+ elif msg .type is aiohttp .WSMsgType .ERROR :
655+ _log .debug ('Received an error %s' , msg )
624656 elif msg .type in (
625657 aiohttp .WSMsgType .CLOSED ,
626658 aiohttp .WSMsgType .CLOSING ,
627659 aiohttp .WSMsgType .CLOSE ,
628- aiohttp .WSMsgType .ERROR ,
629660 ):
630661 _log .debug ("Received %s" , msg )
631662 raise WebSocketClosure
@@ -649,45 +680,51 @@ async def poll_event(self):
649680 self .socket , shard_id = self .shard_id , code = code
650681 ) from None
651682
652- async def debug_send (self , data , / ):
683+ async def debug_send (self , data : str , / ) -> None :
653684 await self ._rate_limiter .block ()
654685 self ._dispatch ("socket_raw_send" , data )
655686 await self .socket .send_str (data )
656687
657- async def send (self , data , / ):
688+ async def send (self , data : str , / ) -> None :
658689 await self ._rate_limiter .block ()
659690 await self .socket .send_str (data )
660691
661- async def send_as_json (self , data ) :
692+ async def send_as_json (self , data : Any ) -> None :
662693 try :
663694 await self .send (utils ._to_json (data ))
664695 except RuntimeError as exc :
665696 if not self ._can_handle_close ():
666697 raise ConnectionClosed (self .socket , shard_id = self .shard_id ) from exc
667698
668- async def send_heartbeat (self , data ) :
699+ async def send_heartbeat (self , data : Any ) -> None :
669700 # This bypasses the rate limit handling code since it has a higher priority
670701 try :
671702 await self .socket .send_str (utils ._to_json (data ))
672703 except RuntimeError as exc :
673704 if not self ._can_handle_close ():
674705 raise ConnectionClosed (self .socket , shard_id = self .shard_id ) from exc
675706
676- async def change_presence (self , * , activity = None , status = None , since = 0.0 ):
707+ async def change_presence (
708+ self ,
709+ * ,
710+ activity : BaseActivity | None = None ,
711+ status : str | None = None ,
712+ since : float = 0.0 ,
713+ ) -> None :
677714 if activity is not None :
678715 if not isinstance (activity , BaseActivity ):
679716 raise InvalidArgument ("activity must derive from BaseActivity." )
680- activity = [activity .to_dict ()]
717+ activities = [activity .to_dict ()]
681718 else :
682- activity = []
719+ activities = []
683720
684721 if status == "idle" :
685722 since = int (time .time () * 1000 )
686723
687724 payload = {
688725 "op" : self .PRESENCE ,
689726 "d" : {
690- "activities" : activity ,
727+ "activities" : activities ,
691728 "afk" : False ,
692729 "since" : since ,
693730 "status" : status ,
@@ -699,8 +736,15 @@ async def change_presence(self, *, activity=None, status=None, since=0.0):
699736 await self .send (sent )
700737
701738 async def request_chunks (
702- self , guild_id , query = None , * , limit , user_ids = None , presences = False , nonce = None
703- ):
739+ self ,
740+ guild_id : int ,
741+ query : str | None = None ,
742+ * ,
743+ limit : int ,
744+ user_ids : list [int ] | None = None ,
745+ presences : bool = False ,
746+ nonce : str | None = None ,
747+ ) -> None :
704748 payload = {
705749 "op" : self .REQUEST_MEMBERS ,
706750 "d" : {"guild_id" : guild_id , "presences" : presences , "limit" : limit },
@@ -717,7 +761,13 @@ async def request_chunks(
717761
718762 await self .send_as_json (payload )
719763
720- async def voice_state (self , guild_id , channel_id , self_mute = False , self_deaf = False ):
764+ async def voice_state (
765+ self ,
766+ guild_id : int ,
767+ channel_id : int ,
768+ self_mute : bool = False ,
769+ self_deaf : bool = False ,
770+ ) -> None :
721771 payload = {
722772 "op" : self .VOICE_STATE ,
723773 "d" : {
@@ -731,7 +781,7 @@ async def voice_state(self, guild_id, channel_id, self_mute=False, self_deaf=Fal
731781 _log .debug ("Updating our voice state to %s." , payload )
732782 await self .send_as_json (payload )
733783
734- async def close (self , code = 4000 ):
784+ async def close (self , code : int = 4000 ) -> None :
735785 if self ._keep_alive :
736786 self ._keep_alive .stop ()
737787 self ._keep_alive = None
0 commit comments