Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 117 additions & 12 deletions discord/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@
from .enums import SpeakingState
from .errors import ConnectionClosed

try:
import davey # type: ignore
except ImportError:
pass

_log = logging.getLogger(__name__)

__all__ = (
Expand Down Expand Up @@ -812,18 +817,30 @@ class DiscordVoiceWebSocket:
_max_heartbeat_timeout: float

# fmt: off
IDENTIFY = 0
SELECT_PROTOCOL = 1
READY = 2
HEARTBEAT = 3
SESSION_DESCRIPTION = 4
SPEAKING = 5
HEARTBEAT_ACK = 6
RESUME = 7
HELLO = 8
RESUMED = 9
CLIENT_CONNECT = 12
CLIENT_DISCONNECT = 13
IDENTIFY = 0
SELECT_PROTOCOL = 1
READY = 2
HEARTBEAT = 3
SESSION_DESCRIPTION = 4
SPEAKING = 5
HEARTBEAT_ACK = 6
RESUME = 7
HELLO = 8
RESUMED = 9
CLIENTS_CONNECT = 11
CLIENT_CONNECT = 12
CLIENT_DISCONNECT = 13
DAVE_PREPARE_TRANSITION = 21
DAVE_EXECUTE_TRANSITION = 22
DAVE_TRANSITION_READY = 23
DAVE_PREPARE_EPOCH = 24
MLS_EXTERNAL_SENDER = 25
MLS_KEY_PACKAGE = 26
MLS_PROPOSALS = 27
MLS_COMMIT_WELCOME = 28
MLS_ANNOUNCE_COMMIT_TRANSITION = 29
MLS_WELCOME = 30
MLS_INVALID_COMMIT_WELCOME = 31
# fmt: on

def __init__(
Expand All @@ -850,6 +867,10 @@ async def send_as_json(self, data: Any) -> None:
_log.debug('Sending voice websocket frame: %s.', data)
await self.ws.send_str(utils._to_json(data))

async def send_binary(self, opcode: int, data: bytes) -> None:
    """Send a binary voice websocket frame.

    The frame is a single opcode byte immediately followed by ``data``.

    Parameters
    -----------
    opcode: :class:`int`
        The opcode placed in the first byte of the frame.
    data: :class:`bytes`
        The payload appended after the opcode byte.
    """
    payload_size = len(data)
    _log.debug('Sending voice websocket binary frame: opcode=%s size=%d', opcode, payload_size)
    frame = bytes((opcode,)) + data
    await self.ws.send_bytes(frame)

send_heartbeat = send_as_json

async def resume(self) -> None:
Expand All @@ -874,6 +895,7 @@ async def identify(self) -> None:
'user_id': str(state.user.id),
'session_id': state.session_id,
'token': state.token,
'max_dave_protocol_version': state.max_dave_protocol_version,
},
}
await self.send_as_json(payload)
Expand Down Expand Up @@ -943,6 +965,16 @@ async def speak(self, state: SpeakingState = SpeakingState.voice) -> None:

await self.send_as_json(payload)

async def send_transition_ready(self, transition_id: int) -> None:
    """Send a ``DAVE_TRANSITION_READY`` frame for the given transition.

    Parameters
    -----------
    transition_id: :class:`int`
        The identifier of the transition being reported as ready.
    """
    payload = {
        # Looked up through ``self`` like the opcode checks in ``received_message``.
        'op': self.DAVE_TRANSITION_READY,
        'd': {
            'transition_id': transition_id,
        },
    }

    await self.send_as_json(payload)

async def received_message(self, msg: Dict[str, Any]) -> None:
_log.debug('Voice websocket frame received: %s', msg)
op = msg['op']
Expand All @@ -959,13 +991,84 @@ async def received_message(self, msg: Dict[str, Any]) -> None:
elif op == self.SESSION_DESCRIPTION:
self._connection.mode = data['mode']
await self.load_secret_key(data)
self._connection.dave_protocol_version = data['dave_protocol_version']
if data['dave_protocol_version'] > 0:
await self._connection.reinit_dave_session()
elif op == self.HELLO:
interval = data['heartbeat_interval'] / 1000.0
self._keep_alive = VoiceKeepAliveHandler(ws=self, interval=min(interval, 5.0))
self._keep_alive.start()
elif self._connection.dave_session:
state = self._connection
if op == self.DAVE_PREPARE_TRANSITION:
_log.debug(
'Preparing for DAVE transition id %d for protocol version %d',
data['transition_id'],
data['protocol_version'],
)
state.dave_pending_transitions[data['transition_id']] = data['protocol_version']
if data['transition_id'] == 0:
await state._execute_transition(data['transition_id'])
else:
if data['protocol_version'] == 0 and state.dave_session:
state.dave_session.set_passthrough_mode(True, 120)

await self.send_transition_ready(data['transition_id'])
elif op == self.DAVE_EXECUTE_TRANSITION:
_log.debug('Executing DAVE transition id %d', data['transition_id'])
await state._execute_transition(data['transition_id'])
elif op == self.DAVE_PREPARE_EPOCH:
_log.debug('Preparing for DAVE epoch %d', data['epoch'])
# When the epoch ID is equal to 1, this message indicates that a new MLS group is to be created for the given protocol version.
if data['epoch'] == 1:
state.dave_protocol_version = data['protocol_version']
await state.reinit_dave_session()

await self._hook(self, msg)

async def recieved_binary_message(self, msg: bytes) -> None:
    """Handle a binary frame received from the voice websocket.

    The frame starts with a big-endian unsigned 16-bit sequence number
    (stored in ``seq_ack``) followed by a one-byte opcode; the remainder is
    the opcode-specific payload. The MLS opcodes are only acted upon while
    the connection has a DAVE session.

    Frames too short to contain the fields being read are logged and dropped
    rather than raising ``struct.error`` or ``IndexError`` to the caller.

    Parameters
    -----------
    msg: :class:`bytes`
        The raw binary frame.
    """
    if len(msg) < 3:
        _log.warning('Ignoring truncated voice websocket binary frame: %d bytes', len(msg))
        return

    self.seq_ack = struct.unpack_from('>H', msg, 0)[0]
    op = msg[2]
    _log.debug('Voice websocket binary frame received: %d bytes; seq=%s op=%s', len(msg), self.seq_ack, op)
    state = self._connection
    session = state.dave_session

    if not session:
        return

    if op == self.MLS_EXTERNAL_SENDER:
        session.set_external_sender(msg[3:])
        _log.debug('Set MLS external sender')
    elif op == self.MLS_PROPOSALS:
        if len(msg) < 4:
            _log.warning('Ignoring truncated MLS proposals frame: %d bytes', len(msg))
            return

        # Byte 3 selects the operation: 0 appends proposals, anything else revokes them.
        if msg[3] == 0:
            operation = davey.ProposalsOperationType.append
        else:
            operation = davey.ProposalsOperationType.revoke

        result = session.process_proposals(operation, msg[4:])
        if isinstance(result, davey.CommitWelcome):
            # The welcome is only appended to the commit when one was produced.
            if result.welcome:
                commit_welcome = result.commit + result.welcome
            else:
                commit_welcome = result.commit
            await self.send_binary(self.MLS_COMMIT_WELCOME, commit_welcome)
        _log.debug('MLS proposals processed')
    elif op in (self.MLS_ANNOUNCE_COMMIT_TRANSITION, self.MLS_WELCOME):
        if op == self.MLS_WELCOME:
            kind, process = 'welcome', session.process_welcome
        else:
            kind, process = 'commit', session.process_commit

        if len(msg) < 5:
            _log.warning('Ignoring truncated MLS %s frame: %d bytes', kind, len(msg))
            return

        transition_id = struct.unpack_from('>H', msg, 3)[0]
        try:
            # Only the processing step is guarded: a failure while sending the
            # ready frame afterwards is not an invalid commit/welcome.
            process(msg[5:])
        except Exception:
            _log.warning(
                'Failed to process MLS %s for transition id %d, recovering',
                kind,
                transition_id,
                exc_info=True,
            )
            await state._recover_from_invalid_commit(transition_id)
            return

        if transition_id != 0:
            state.dave_pending_transitions[transition_id] = state.dave_protocol_version
            await self.send_transition_ready(transition_id)
        _log.debug('MLS %s processed for transition id %d', kind, transition_id)

async def initial_connection(self, data: Dict[str, Any]) -> None:
state = self._connection
state.ssrc = data['ssrc']
Expand Down Expand Up @@ -1045,6 +1148,8 @@ async def poll_event(self) -> None:
msg = await asyncio.wait_for(self.ws.receive(), timeout=30.0)
if msg.type is aiohttp.WSMsgType.TEXT:
await self.received_message(utils._from_json(msg.data))
elif msg.type is aiohttp.WSMsgType.BINARY:
await self.recieved_binary_message(msg.data)
elif msg.type is aiohttp.WSMsgType.ERROR:
_log.debug('Received voice %s', msg)
raise ConnectionClosed(self.ws, shard_id=None) from msg.data
Expand Down
18 changes: 16 additions & 2 deletions discord/voice_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,15 @@ def ws(self) -> DiscordVoiceWebSocket:
def timeout(self) -> float:
return self._connection.timeout

@property
def voice_privacy_code(self) -> Optional[str]:
    """Optional[:class:`str`]: The voice privacy code of this E2EE session's group.

    A new privacy code is created and cached each time a new transition is executed.
    ``None`` is returned when there is no active DAVE session.
    """
    session = self._connection.dave_session
    if not session:
        return None
    return session.voice_privacy_code

def checked_add(self, attr: str, value: int, limit: int) -> None:
val = getattr(self, attr)
if val + value > limit:
Expand Down Expand Up @@ -368,7 +377,12 @@ def is_connected(self) -> bool:

# audio related

def _get_voice_packet(self, data):
def _get_voice_packet(self, data: bytes):
packet = (
self._connection.dave_session.encrypt_opus(data)
if self._connection.dave_session and self._connection.can_encrypt
else data
)
header = bytearray(12)

# Formulate rtp header
Expand All @@ -379,7 +393,7 @@ def _get_voice_packet(self, data):
struct.pack_into('>I', header, 8, self.ssrc)

encrypt_packet = getattr(self, '_encrypt_' + self.mode)
return encrypt_packet(header, data)
return encrypt_packet(header, packet)

def _encrypt_aead_xchacha20_poly1305_rtpsize(self, header: bytes, data) -> bytes:
# Essentially the same as _lite
Expand Down
70 changes: 70 additions & 0 deletions discord/voice_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,14 @@
WebsocketHook = Optional[Callable[[DiscordVoiceWebSocket, Dict[str, Any]], Coroutine[Any, Any, Any]]]
SocketReaderCallback = Callable[[bytes], Any]

has_dave: bool

try:
import davey # type: ignore

has_dave = True
except ImportError:
has_dave = False

__all__ = ('VoiceConnectionState',)

Expand Down Expand Up @@ -208,6 +216,10 @@ def __init__(self, voice_client: VoiceClient, *, hook: Optional[WebsocketHook] =
self.mode: SupportedModes = MISSING
self.socket: socket.socket = MISSING
self.ws: DiscordVoiceWebSocket = MISSING
self.dave_session: Optional[davey.DaveSession] = None
self.dave_protocol_version: int = 0
self.dave_pending_transitions: Dict[int, int] = {}
self.dave_downgraded: bool = False

self._state: ConnectionFlowState = ConnectionFlowState.disconnected
self._expecting_disconnect: bool = False
Expand Down Expand Up @@ -252,6 +264,64 @@ def supported_modes(self) -> Tuple[SupportedModes, ...]:
def self_voice_state(self) -> Optional[VoiceState]:
return self.guild.me.voice

@property
def max_dave_protocol_version(self) -> int:
    """:class:`int`: The DAVE protocol version exposed by davey, or ``0`` when davey could not be imported."""
    if not has_dave:
        return 0
    return davey.DAVE_PROTOCOL_VERSION

@property
def can_encrypt(self) -> bool:
    """:class:`bool`: Whether the DAVE session can currently be used for encryption.

    This is ``True`` only when a non-zero DAVE protocol version is in effect,
    a DAVE session exists, and that session reports itself as ready.
    """
    session = self.dave_session
    # Identity check: ``!= None`` would defer to the session object's own
    # comparison methods instead of testing for the ``None`` singleton.
    if self.dave_protocol_version == 0 or session is None:
        return False
    return bool(session.ready)

async def reinit_dave_session(self) -> None:
    """(Re)initialise the DAVE session for the current protocol version.

    With a positive ``dave_protocol_version`` the existing session is
    re-initialised, or a new ``davey.DaveSession`` is created, using the
    protocol version, the user id and the voice channel id. The session's
    serialized key package is then sent as an ``MLS_KEY_PACKAGE`` binary frame.

    Otherwise, an existing session is reset and switched to passthrough mode.

    Raises
    -------
    RuntimeError
        The protocol version is positive but davey is not available.
    """
    version = self.dave_protocol_version

    if version <= 0:
        if self.dave_session:
            self.dave_session.reset()
            self.dave_session.set_passthrough_mode(True, 10)
        return

    if not has_dave:
        raise RuntimeError('davey library needed in order to use E2EE voice')

    if self.dave_session is None:
        self.dave_session = davey.DaveSession(version, self.user.id, self.voice_client.channel.id)
    else:
        self.dave_session.reinit(version, self.user.id, self.voice_client.channel.id)

    # A session always exists at this point, so the key package can be sent unconditionally.
    await self.voice_client.ws.send_binary(
        DiscordVoiceWebSocket.MLS_KEY_PACKAGE,
        self.dave_session.get_serialized_key_package(),
    )

async def _recover_from_invalid_commit(self, transition_id: int) -> None:
    """Send ``MLS_INVALID_COMMIT_WELCOME`` for ``transition_id``, then re-initialise the DAVE session."""
    invalid_notice = {
        'op': DiscordVoiceWebSocket.MLS_INVALID_COMMIT_WELCOME,
        'd': {'transition_id': transition_id},
    }
    websocket = self.voice_client.ws

    await websocket.send_as_json(invalid_notice)
    await self.reinit_dave_session()

async def _execute_transition(self, transition_id: int) -> None:
    """Apply the pending DAVE transition registered under ``transition_id``.

    The pending entry is removed and its protocol version becomes the current
    ``dave_protocol_version``. A change to version 0 marks the session as
    downgraded; a later transition with a positive id clears that mark and
    puts an existing session into passthrough mode. Unknown transition ids
    are logged and ignored.
    """
    _log.debug('Executing transition id %d', transition_id)

    try:
        new_version = self.dave_pending_transitions.pop(transition_id)
    except KeyError:
        _log.warning("Received execute transition, but we don't have a pending transition for id %d", transition_id)
        return

    previous_version = self.dave_protocol_version
    self.dave_protocol_version = new_version

    if new_version == 0 and previous_version != new_version:
        self.dave_downgraded = True
        _log.debug('DAVE Session downgraded')
    elif self.dave_downgraded and transition_id > 0:
        self.dave_downgraded = False
        if self.dave_session:
            self.dave_session.set_passthrough_mode(True, 10)
        _log.debug('DAVE Session upgraded')

    # In the future the session should be signaled as well, but for now there is only v1.
    _log.debug('Transition id %d executed', transition_id)

async def voice_state_update(self, data: GuildVoiceStatePayload) -> None:
channel_id = data['channel_id']

Expand Down
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ Documentation = "https://discordpy.readthedocs.io/en/latest/"
dependencies = { file = "requirements.txt" }

[project.optional-dependencies]
voice = ["PyNaCl>=1.5.0,<1.6"]
voice = [
"PyNaCl>=1.5.0,<1.6",
"davey==0.1.0"
]
docs = [
"sphinx==4.4.0",
"sphinxcontrib_trio==1.1.2",
Expand Down
Loading