Skip to content

Commit a98025a

Browse files
committed
first commit on voice fixes
1 parent 2b63af9 commit a98025a

File tree

13 files changed

+821
-191
lines changed

13 files changed

+821
-191
lines changed

discord/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@
7171
from .template import *
7272
from .threads import *
7373
from .user import *
74-
from .voice_client import *
74+
from .voice import *
7575
from .webhook import *
7676
from .welcome_screen import *
7777
from .widget import *

discord/gateway.py

Lines changed: 115 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from __future__ import annotations
2727

2828
import asyncio
29+
from collections.abc import Callable
2930
import concurrent.futures
3031
import logging
3132
import struct
@@ -34,8 +35,8 @@
3435
import time
3536
import traceback
3637
import zlib
37-
from collections import deque, namedtuple
38-
from typing import TYPE_CHECKING
38+
from collections import deque
39+
from typing import TYPE_CHECKING, Any, NamedTuple
3940

4041
import aiohttp
4142

@@ -44,6 +45,12 @@
4445
from .enums import SpeakingState
4546
from .errors import ConnectionClosed, InvalidArgument
4647

48+
if TYPE_CHECKING:
49+
from typing_extensions import Self
50+
51+
from .client import Client
52+
from .state import ConnectionState
53+
4754
_log = logging.getLogger(__name__)
4855

4956
__all__ = (
@@ -68,26 +75,30 @@ class WebSocketClosure(Exception):
6875
"""An exception to make up for the fact that aiohttp doesn't signal closure."""
6976

7077

71-
EventListener = namedtuple("EventListener", "predicate event result future")
78+
class EventListener(NamedTuple):
79+
predicate: Callable[[dict[str, Any]], bool]
80+
event: str
81+
result: Callable[[dict[str, Any]], Any] | None
82+
future: asyncio.Future[Any]
7283

7384

7485
class GatewayRatelimiter:
75-
def __init__(self, count=110, per=60.0):
86+
def __init__(self, count: int = 110, per: float = 60.0):
7687
# The default is 110 to give room for at least 10 heartbeats per minute
77-
self.max = count
78-
self.remaining = count
79-
self.window = 0.0
80-
self.per = per
81-
self.lock = asyncio.Lock()
82-
self.shard_id = None
83-
84-
def is_ratelimited(self):
88+
self.max: int = count
89+
self.remaining: int = count
90+
self.window: float = 0.0
91+
self.per: float = per
92+
self.lock: asyncio.Lock = asyncio.Lock()
93+
self.shard_id: int | None = None
94+
95+
def is_ratelimited(self) -> bool:
8596
current = time.time()
8697
if current > self.window + self.per:
8798
return False
8899
return self.remaining == 0
89100

90-
def get_delay(self):
101+
def get_delay(self) -> float:
91102
current = time.time()
92103

93104
if current > self.window + self.per:
@@ -105,7 +116,7 @@ def get_delay(self):
105116

106117
return 0.0
107118

108-
async def block(self):
119+
async def block(self) -> None:
109120
async with self.lock:
110121
delta = self.get_delay()
111122
if delta:
@@ -118,12 +129,16 @@ async def block(self):
118129

119130

120131
class KeepAliveHandler(threading.Thread):
121-
def __init__(self, *args, **kwargs):
122-
ws = kwargs.pop("ws", None)
123-
interval = kwargs.pop("interval", None)
124-
shard_id = kwargs.pop("shard_id", None)
132+
def __init__(
133+
self,
134+
*args: Any,
135+
ws: DiscordWebSocket,
136+
shard_id: int | None = None,
137+
interval: float | None = None,
138+
**kwargs: Any,
139+
) -> None:
125140
threading.Thread.__init__(self, *args, **kwargs)
126-
self.ws = ws
141+
self.ws: DiscordWebSocket = ws
127142
self._main_thread_id = ws.thread_id
128143
self.interval = interval
129144
self.daemon = True
@@ -292,52 +307,63 @@ class DiscordWebSocket:
292307
HEARTBEAT_ACK = 11
293308
GUILD_SYNC = 12
294309

295-
def __init__(self, socket, *, loop):
296-
self.socket = socket
297-
self.loop = loop
310+
if TYPE_CHECKING:
311+
token: str | None
312+
_connection: ConnectionState
313+
_discord_parsers: dict[str, Callable[..., Any]]
314+
call_hooks: Callable[..., Any]
315+
gateway: str
316+
_initial_identify: bool
317+
shard_id: int | None
318+
shard_count: int | None
319+
_max_heartbeat_timeout: float
320+
321+
def __init__(self, socket: aiohttp.ClientWebSocketResponse, *, loop: asyncio.AbstractEventLoop) -> None:
322+
self.socket: aiohttp.ClientWebSocketResponse = socket
323+
self.loop: asyncio.AbstractEventLoop = loop
298324

299325
# an empty dispatcher to prevent crashes
300-
self._dispatch = lambda *args: None
326+
self._dispatch: Callable[..., Any] = lambda *args: None
301327
# generic event listeners
302-
self._dispatch_listeners = []
328+
self._dispatch_listeners: list[EventListener] = []
303329
# the keep alive
304-
self._keep_alive = None
305-
self.thread_id = threading.get_ident()
330+
self._keep_alive: KeepAliveHandler | None = None
331+
self.thread_id: int = threading.get_ident()
306332

307333
# ws related stuff
308-
self.session_id = None
309-
self.sequence = None
310-
self.resume_gateway_url = None
311-
self._zlib = zlib.decompressobj()
312-
self._buffer = bytearray()
313-
self._close_code = None
314-
self._rate_limiter = GatewayRatelimiter()
334+
self.session_id: str | None = None
335+
self.sequence: int | None = None
336+
self.resume_gateway_url: str | None = None
337+
self._zlib: zlib._Decompress = zlib.decompressobj()
338+
self._buffer: bytearray = bytearray()
339+
self._close_code: int | None = None
340+
self._rate_limiter: GatewayRatelimiter = GatewayRatelimiter()
315341

316342
@property
317-
def open(self):
343+
def open(self) -> bool:
318344
return not self.socket.closed
319345

320-
def is_ratelimited(self):
346+
def is_ratelimited(self) -> bool:
321347
return self._rate_limiter.is_ratelimited()
322348

323-
def debug_log_receive(self, data, /):
349+
def debug_log_receive(self, data: dict[str, Any], /) -> None:
324350
self._dispatch("socket_raw_receive", data)
325351

326-
def log_receive(self, _, /):
352+
def log_receive(self, _: dict[str, Any], /) -> None:
327353
pass
328354

329355
@classmethod
330356
async def from_client(
331357
cls,
332-
client,
358+
client: Client,
333359
*,
334-
initial=False,
335-
gateway=None,
336-
shard_id=None,
337-
session=None,
338-
sequence=None,
339-
resume=False,
340-
):
360+
initial: bool = False,
361+
gateway: str | None = None,
362+
shard_id: int | None = None,
363+
session: str | None = None,
364+
sequence: int | None = None,
365+
resume: bool = False,
366+
) -> Self:
341367
"""Creates a main websocket for Discord from a :class:`Client`.
342368
343369
This is for internal use only.
@@ -379,7 +405,12 @@ async def from_client(
379405
await ws.resume()
380406
return ws
381407

382-
def wait_for(self, event, predicate, result=None):
408+
def wait_for(
409+
self,
410+
event: str,
411+
predicate: Callable[[dict[str, Any]], bool],
412+
result: Callable[[dict[str, Any]], Any] | None = None,
413+
) -> asyncio.Future[Any]:
383414
"""Waits for a DISPATCH'd event that meets the predicate.
384415
385416
Parameters
@@ -406,7 +437,7 @@ def wait_for(self, event, predicate, result=None):
406437
self._dispatch_listeners.append(entry)
407438
return future
408439

409-
async def identify(self):
440+
async def identify(self) -> None:
410441
"""Sends the IDENTIFY packet."""
411442
payload = {
412443
"op": self.IDENTIFY,
@@ -419,7 +450,6 @@ async def identify(self):
419450
},
420451
"compress": True,
421452
"large_threshold": 250,
422-
"v": 3,
423453
},
424454
}
425455

@@ -444,7 +474,7 @@ async def identify(self):
444474
await self.send_as_json(payload)
445475
_log.info("Shard ID %s has sent the IDENTIFY payload.", self.shard_id)
446476

447-
async def resume(self):
477+
async def resume(self) -> None:
448478
"""Sends the RESUME packet."""
449479
payload = {
450480
"op": self.RESUME,
@@ -458,7 +488,7 @@ async def resume(self):
458488
await self.send_as_json(payload)
459489
_log.info("Shard ID %s has sent the RESUME payload.", self.shard_id)
460490

461-
async def received_message(self, msg, /):
491+
async def received_message(self, msg: Any, /):
462492
if type(msg) is bytes:
463493
self._buffer.extend(msg)
464494

@@ -594,7 +624,7 @@ def latency(self) -> float:
594624
heartbeat = self._keep_alive
595625
return float("inf") if heartbeat is None else heartbeat.latency
596626

597-
def _can_handle_close(self):
627+
def _can_handle_close(self) -> bool:
598628
code = self._close_code or self.socket.close_code
599629
is_improper_close = self._close_code is None and self.socket.close_code == 1000
600630
return is_improper_close or code not in (
@@ -607,7 +637,7 @@ def _can_handle_close(self):
607637
4014,
608638
)
609639

610-
async def poll_event(self):
640+
async def poll_event(self) -> None:
611641
"""Polls for a DISPATCH event and handles the general gateway loop.
612642
613643
Raises
@@ -621,11 +651,12 @@ async def poll_event(self):
621651
await self.received_message(msg.data)
622652
elif msg.type is aiohttp.WSMsgType.BINARY:
623653
await self.received_message(msg.data)
654+
elif msg.type is aiohttp.WSMsgType.ERROR:
655+
_log.debug('Received an error %s', msg)
624656
elif msg.type in (
625657
aiohttp.WSMsgType.CLOSED,
626658
aiohttp.WSMsgType.CLOSING,
627659
aiohttp.WSMsgType.CLOSE,
628-
aiohttp.WSMsgType.ERROR,
629660
):
630661
_log.debug("Received %s", msg)
631662
raise WebSocketClosure
@@ -649,45 +680,51 @@ async def poll_event(self):
649680
self.socket, shard_id=self.shard_id, code=code
650681
) from None
651682

652-
async def debug_send(self, data, /):
683+
async def debug_send(self, data: str, /) -> None:
653684
await self._rate_limiter.block()
654685
self._dispatch("socket_raw_send", data)
655686
await self.socket.send_str(data)
656687

657-
async def send(self, data, /):
688+
async def send(self, data: str, /) -> None:
658689
await self._rate_limiter.block()
659690
await self.socket.send_str(data)
660691

661-
async def send_as_json(self, data):
692+
async def send_as_json(self, data: Any) -> None:
662693
try:
663694
await self.send(utils._to_json(data))
664695
except RuntimeError as exc:
665696
if not self._can_handle_close():
666697
raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc
667698

668-
async def send_heartbeat(self, data):
699+
async def send_heartbeat(self, data: Any) -> None:
669700
# This bypasses the rate limit handling code since it has a higher priority
670701
try:
671702
await self.socket.send_str(utils._to_json(data))
672703
except RuntimeError as exc:
673704
if not self._can_handle_close():
674705
raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc
675706

676-
async def change_presence(self, *, activity=None, status=None, since=0.0):
707+
async def change_presence(
708+
self,
709+
*,
710+
activity: BaseActivity | None = None,
711+
status: str | None = None,
712+
since: float = 0.0,
713+
) -> None:
677714
if activity is not None:
678715
if not isinstance(activity, BaseActivity):
679716
raise InvalidArgument("activity must derive from BaseActivity.")
680-
activity = [activity.to_dict()]
717+
activities = [activity.to_dict()]
681718
else:
682-
activity = []
719+
activities = []
683720

684721
if status == "idle":
685722
since = int(time.time() * 1000)
686723

687724
payload = {
688725
"op": self.PRESENCE,
689726
"d": {
690-
"activities": activity,
727+
"activities": activities,
691728
"afk": False,
692729
"since": since,
693730
"status": status,
@@ -699,8 +736,15 @@ async def change_presence(self, *, activity=None, status=None, since=0.0):
699736
await self.send(sent)
700737

701738
async def request_chunks(
702-
self, guild_id, query=None, *, limit, user_ids=None, presences=False, nonce=None
703-
):
739+
self,
740+
guild_id: int,
741+
query: str | None = None,
742+
*,
743+
limit: int,
744+
user_ids: list[int] | None = None,
745+
presences: bool = False,
746+
nonce: str | None = None,
747+
) -> None:
704748
payload = {
705749
"op": self.REQUEST_MEMBERS,
706750
"d": {"guild_id": guild_id, "presences": presences, "limit": limit},
@@ -717,7 +761,13 @@ async def request_chunks(
717761

718762
await self.send_as_json(payload)
719763

720-
async def voice_state(self, guild_id, channel_id, self_mute=False, self_deaf=False):
764+
async def voice_state(
765+
self,
766+
guild_id: int,
767+
channel_id: int,
768+
self_mute: bool = False,
769+
self_deaf: bool = False,
770+
) -> None:
721771
payload = {
722772
"op": self.VOICE_STATE,
723773
"d": {
@@ -731,7 +781,7 @@ async def voice_state(self, guild_id, channel_id, self_mute=False, self_deaf=Fal
731781
_log.debug("Updating our voice state to %s.", payload)
732782
await self.send_as_json(payload)
733783

734-
async def close(self, code=4000):
784+
async def close(self, code: int = 4000) -> None:
735785
if self._keep_alive:
736786
self._keep_alive.stop()
737787
self._keep_alive = None

0 commit comments

Comments
 (0)