26
26
from __future__ import annotations
27
27
28
28
import asyncio
29
+ from collections .abc import Callable
29
30
import concurrent .futures
30
31
import logging
31
32
import struct
34
35
import time
35
36
import traceback
36
37
import zlib
37
- from collections import deque , namedtuple
38
- from typing import TYPE_CHECKING
38
+ from collections import deque
39
+ from typing import TYPE_CHECKING , Any , NamedTuple
39
40
40
41
import aiohttp
41
42
44
45
from .enums import SpeakingState
45
46
from .errors import ConnectionClosed , InvalidArgument
46
47
48
+ if TYPE_CHECKING :
49
+ from typing_extensions import Self
50
+
51
+ from .client import Client
52
+ from .state import ConnectionState
53
+
47
54
_log = logging .getLogger (__name__ )
48
55
49
56
__all__ = (
@@ -68,26 +75,30 @@ class WebSocketClosure(Exception):
68
75
"""An exception to make up for the fact that aiohttp doesn't signal closure."""
69
76
70
77
71
- EventListener = namedtuple ("EventListener" , "predicate event result future" )
78
+ class EventListener (NamedTuple ):
79
+ predicate : Callable [[dict [str , Any ]], bool ]
80
+ event : str
81
+ result : Callable [[dict [str , Any ]], Any ] | None
82
+ future : asyncio .Future [Any ]
72
83
73
84
74
85
class GatewayRatelimiter :
75
- def __init__ (self , count = 110 , per = 60.0 ):
86
+ def __init__ (self , count : int = 110 , per : float = 60.0 ):
76
87
# The default is 110 to give room for at least 10 heartbeats per minute
77
- self .max = count
78
- self .remaining = count
79
- self .window = 0.0
80
- self .per = per
81
- self .lock = asyncio .Lock ()
82
- self .shard_id = None
83
-
84
- def is_ratelimited (self ):
88
+ self .max : int = count
89
+ self .remaining : int = count
90
+ self .window : float = 0.0
91
+ self .per : float = per
92
+ self .lock : asyncio . Lock = asyncio .Lock ()
93
+ self .shard_id : int | None = None
94
+
95
+ def is_ratelimited (self ) -> bool :
85
96
current = time .time ()
86
97
if current > self .window + self .per :
87
98
return False
88
99
return self .remaining == 0
89
100
90
- def get_delay (self ):
101
+ def get_delay (self ) -> float :
91
102
current = time .time ()
92
103
93
104
if current > self .window + self .per :
@@ -105,7 +116,7 @@ def get_delay(self):
105
116
106
117
return 0.0
107
118
108
- async def block (self ):
119
+ async def block (self ) -> None :
109
120
async with self .lock :
110
121
delta = self .get_delay ()
111
122
if delta :
@@ -118,12 +129,16 @@ async def block(self):
118
129
119
130
120
131
class KeepAliveHandler (threading .Thread ):
121
- def __init__ (self , * args , ** kwargs ):
122
- ws = kwargs .pop ("ws" , None )
123
- interval = kwargs .pop ("interval" , None )
124
- shard_id = kwargs .pop ("shard_id" , None )
132
+ def __init__ (
133
+ self ,
134
+ * args : Any ,
135
+ ws : DiscordWebSocket ,
136
+ shard_id : int | None = None ,
137
+ interval : float | None = None ,
138
+ ** kwargs : Any ,
139
+ ) -> None :
125
140
threading .Thread .__init__ (self , * args , ** kwargs )
126
- self .ws = ws
141
+ self .ws : DiscordWebSocket = ws
127
142
self ._main_thread_id = ws .thread_id
128
143
self .interval = interval
129
144
self .daemon = True
@@ -292,52 +307,63 @@ class DiscordWebSocket:
292
307
HEARTBEAT_ACK = 11
293
308
GUILD_SYNC = 12
294
309
295
- def __init__ (self , socket , * , loop ):
296
- self .socket = socket
297
- self .loop = loop
310
+ if TYPE_CHECKING :
311
+ token : str | None
312
+ _connection : ConnectionState
313
+ _discord_parsers : dict [str , Callable [..., Any ]]
314
+ call_hooks : Callable [..., Any ]
315
+ gateway : str
316
+ _initial_identify : bool
317
+ shard_id : int | None
318
+ shard_count : int | None
319
+ _max_heartbeat_timeout : float
320
+
321
+ def __init__ (self , socket : aiohttp .ClientWebSocketResponse , * , loop : asyncio .AbstractEventLoop ) -> None :
322
+ self .socket : aiohttp .ClientWebSocketResponse = socket
323
+ self .loop : asyncio .AbstractEventLoop = loop
298
324
299
325
# an empty dispatcher to prevent crashes
300
- self ._dispatch = lambda * args : None
326
+ self ._dispatch : Callable [..., Any ] = lambda * args : None
301
327
# generic event listeners
302
- self ._dispatch_listeners = []
328
+ self ._dispatch_listeners : list [ EventListener ] = []
303
329
# the keep alive
304
- self ._keep_alive = None
305
- self .thread_id = threading .get_ident ()
330
+ self ._keep_alive : KeepAliveHandler | None = None
331
+ self .thread_id : int = threading .get_ident ()
306
332
307
333
# ws related stuff
308
- self .session_id = None
309
- self .sequence = None
310
- self .resume_gateway_url = None
311
- self ._zlib = zlib .decompressobj ()
312
- self ._buffer = bytearray ()
313
- self ._close_code = None
314
- self ._rate_limiter = GatewayRatelimiter ()
334
+ self .session_id : str | None = None
335
+ self .sequence : int | None = None
336
+ self .resume_gateway_url : str | None = None
337
+ self ._zlib : zlib . _Decompress = zlib .decompressobj ()
338
+ self ._buffer : bytearray = bytearray ()
339
+ self ._close_code : int | None = None
340
+ self ._rate_limiter : GatewayRatelimiter = GatewayRatelimiter ()
315
341
316
342
@property
317
- def open (self ):
343
+ def open (self ) -> bool :
318
344
return not self .socket .closed
319
345
320
- def is_ratelimited (self ):
346
+ def is_ratelimited (self ) -> bool :
321
347
return self ._rate_limiter .is_ratelimited ()
322
348
323
- def debug_log_receive (self , data , / ) :
349
+ def debug_log_receive (self , data : dict [ str , Any ], / ) -> None :
324
350
self ._dispatch ("socket_raw_receive" , data )
325
351
326
- def log_receive (self , _ , / ) :
352
+ def log_receive (self , _ : dict [ str , Any ], / ) -> None :
327
353
pass
328
354
329
355
@classmethod
330
356
async def from_client (
331
357
cls ,
332
- client ,
358
+ client : Client ,
333
359
* ,
334
- initial = False ,
335
- gateway = None ,
336
- shard_id = None ,
337
- session = None ,
338
- sequence = None ,
339
- resume = False ,
340
- ):
360
+ initial : bool = False ,
361
+ gateway : str | None = None ,
362
+ shard_id : int | None = None ,
363
+ session : str | None = None ,
364
+ sequence : int | None = None ,
365
+ resume : bool = False ,
366
+ ) -> Self :
341
367
"""Creates a main websocket for Discord from a :class:`Client`.
342
368
343
369
This is for internal use only.
@@ -379,7 +405,12 @@ async def from_client(
379
405
await ws .resume ()
380
406
return ws
381
407
382
- def wait_for (self , event , predicate , result = None ):
408
+ def wait_for (
409
+ self ,
410
+ event : str ,
411
+ predicate : Callable [[dict [str , Any ]], bool ],
412
+ result : Callable [[dict [str , Any ]], Any ] | None = None ,
413
+ ) -> asyncio .Future [Any ]:
383
414
"""Waits for a DISPATCH'd event that meets the predicate.
384
415
385
416
Parameters
@@ -406,7 +437,7 @@ def wait_for(self, event, predicate, result=None):
406
437
self ._dispatch_listeners .append (entry )
407
438
return future
408
439
409
- async def identify (self ):
440
+ async def identify (self ) -> None :
410
441
"""Sends the IDENTIFY packet."""
411
442
payload = {
412
443
"op" : self .IDENTIFY ,
@@ -419,7 +450,6 @@ async def identify(self):
419
450
},
420
451
"compress" : True ,
421
452
"large_threshold" : 250 ,
422
- "v" : 3 ,
423
453
},
424
454
}
425
455
@@ -444,7 +474,7 @@ async def identify(self):
444
474
await self .send_as_json (payload )
445
475
_log .info ("Shard ID %s has sent the IDENTIFY payload." , self .shard_id )
446
476
447
- async def resume (self ):
477
+ async def resume (self ) -> None :
448
478
"""Sends the RESUME packet."""
449
479
payload = {
450
480
"op" : self .RESUME ,
@@ -458,7 +488,7 @@ async def resume(self):
458
488
await self .send_as_json (payload )
459
489
_log .info ("Shard ID %s has sent the RESUME payload." , self .shard_id )
460
490
461
- async def received_message (self , msg , / ):
491
+ async def received_message (self , msg : Any , / ):
462
492
if type (msg ) is bytes :
463
493
self ._buffer .extend (msg )
464
494
@@ -594,7 +624,7 @@ def latency(self) -> float:
594
624
heartbeat = self ._keep_alive
595
625
return float ("inf" ) if heartbeat is None else heartbeat .latency
596
626
597
- def _can_handle_close (self ):
627
+ def _can_handle_close (self ) -> bool :
598
628
code = self ._close_code or self .socket .close_code
599
629
is_improper_close = self ._close_code is None and self .socket .close_code == 1000
600
630
return is_improper_close or code not in (
@@ -607,7 +637,7 @@ def _can_handle_close(self):
607
637
4014 ,
608
638
)
609
639
610
- async def poll_event (self ):
640
+ async def poll_event (self ) -> None :
611
641
"""Polls for a DISPATCH event and handles the general gateway loop.
612
642
613
643
Raises
@@ -621,11 +651,12 @@ async def poll_event(self):
621
651
await self .received_message (msg .data )
622
652
elif msg .type is aiohttp .WSMsgType .BINARY :
623
653
await self .received_message (msg .data )
654
+ elif msg .type is aiohttp .WSMsgType .ERROR :
655
+ _log .debug ('Received an error %s' , msg )
624
656
elif msg .type in (
625
657
aiohttp .WSMsgType .CLOSED ,
626
658
aiohttp .WSMsgType .CLOSING ,
627
659
aiohttp .WSMsgType .CLOSE ,
628
- aiohttp .WSMsgType .ERROR ,
629
660
):
630
661
_log .debug ("Received %s" , msg )
631
662
raise WebSocketClosure
@@ -649,45 +680,51 @@ async def poll_event(self):
649
680
self .socket , shard_id = self .shard_id , code = code
650
681
) from None
651
682
652
- async def debug_send (self , data , / ):
683
+ async def debug_send (self , data : str , / ) -> None :
653
684
await self ._rate_limiter .block ()
654
685
self ._dispatch ("socket_raw_send" , data )
655
686
await self .socket .send_str (data )
656
687
657
- async def send (self , data , / ):
688
+ async def send (self , data : str , / ) -> None :
658
689
await self ._rate_limiter .block ()
659
690
await self .socket .send_str (data )
660
691
661
- async def send_as_json (self , data ) :
692
+ async def send_as_json (self , data : Any ) -> None :
662
693
try :
663
694
await self .send (utils ._to_json (data ))
664
695
except RuntimeError as exc :
665
696
if not self ._can_handle_close ():
666
697
raise ConnectionClosed (self .socket , shard_id = self .shard_id ) from exc
667
698
668
- async def send_heartbeat (self , data ) :
699
+ async def send_heartbeat (self , data : Any ) -> None :
669
700
# This bypasses the rate limit handling code since it has a higher priority
670
701
try :
671
702
await self .socket .send_str (utils ._to_json (data ))
672
703
except RuntimeError as exc :
673
704
if not self ._can_handle_close ():
674
705
raise ConnectionClosed (self .socket , shard_id = self .shard_id ) from exc
675
706
676
- async def change_presence (self , * , activity = None , status = None , since = 0.0 ):
707
+ async def change_presence (
708
+ self ,
709
+ * ,
710
+ activity : BaseActivity | None = None ,
711
+ status : str | None = None ,
712
+ since : float = 0.0 ,
713
+ ) -> None :
677
714
if activity is not None :
678
715
if not isinstance (activity , BaseActivity ):
679
716
raise InvalidArgument ("activity must derive from BaseActivity." )
680
- activity = [activity .to_dict ()]
717
+ activities = [activity .to_dict ()]
681
718
else :
682
- activity = []
719
+ activities = []
683
720
684
721
if status == "idle" :
685
722
since = int (time .time () * 1000 )
686
723
687
724
payload = {
688
725
"op" : self .PRESENCE ,
689
726
"d" : {
690
- "activities" : activity ,
727
+ "activities" : activities ,
691
728
"afk" : False ,
692
729
"since" : since ,
693
730
"status" : status ,
@@ -699,8 +736,15 @@ async def change_presence(self, *, activity=None, status=None, since=0.0):
699
736
await self .send (sent )
700
737
701
738
async def request_chunks (
702
- self , guild_id , query = None , * , limit , user_ids = None , presences = False , nonce = None
703
- ):
739
+ self ,
740
+ guild_id : int ,
741
+ query : str | None = None ,
742
+ * ,
743
+ limit : int ,
744
+ user_ids : list [int ] | None = None ,
745
+ presences : bool = False ,
746
+ nonce : str | None = None ,
747
+ ) -> None :
704
748
payload = {
705
749
"op" : self .REQUEST_MEMBERS ,
706
750
"d" : {"guild_id" : guild_id , "presences" : presences , "limit" : limit },
@@ -717,7 +761,13 @@ async def request_chunks(
717
761
718
762
await self .send_as_json (payload )
719
763
720
- async def voice_state (self , guild_id , channel_id , self_mute = False , self_deaf = False ):
764
+ async def voice_state (
765
+ self ,
766
+ guild_id : int ,
767
+ channel_id : int ,
768
+ self_mute : bool = False ,
769
+ self_deaf : bool = False ,
770
+ ) -> None :
721
771
payload = {
722
772
"op" : self .VOICE_STATE ,
723
773
"d" : {
@@ -731,7 +781,7 @@ async def voice_state(self, guild_id, channel_id, self_mute=False, self_deaf=Fal
731
781
_log .debug ("Updating our voice state to %s." , payload )
732
782
await self .send_as_json (payload )
733
783
734
- async def close (self , code = 4000 ):
784
+ async def close (self , code : int = 4000 ) -> None :
735
785
if self ._keep_alive :
736
786
self ._keep_alive .stop ()
737
787
self ._keep_alive = None
0 commit comments