Skip to content

Commit 13378e6

Browse files
committed
Add dialer and listener to webrtc-direct
1 parent 0784f1c commit 13378e6

20 files changed

+887
-373
lines changed

libp2p/transport/webrtc/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import sys
99
from .private_to_private.transport import WebRTCTransport
1010
from .private_to_public.transport import WebRTCDirectTransport
11-
from ..constants import (
11+
from .constants import (
1212
DEFAULT_ICE_SERVERS,
1313
SIGNALING_PROTOCOL,
1414
MUXER_PROTOCOL,
File renamed without changes.

libp2p/transport/webrtc/listener.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141
logging.basicConfig(level=logging.INFO)
4242
SIGNAL_PROTOCOL: TProtocol = TProtocol("/libp2p/webrtc/signal/1.0.0")
4343

44-
4544
class WebRTCListener(IListener):
4645
"""
4746
WebRTC Listener Implementation.

libp2p/transport/webrtc/multiaddr_codecs.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,9 @@ def webrtc_direct_decode(b: ByteString) -> str:
3333
return ""
3434

3535

36-
def certhash_encode(s: str) -> ByteString:
37-
"""Encode certificate hash component."""
36+
def certhash_decode(s: str) -> Tuple[int, bytes]:
3837
if not s:
39-
return b""
38+
raise ValueError("Empty certhash string.")
4039

4140
# Remove multibase prefix if present
4241
if s.startswith("uEi"):
@@ -46,16 +45,29 @@ def certhash_encode(s: str) -> ByteString:
4645

4746
# Decode base64url encoded hash
4847
try:
49-
# Ensure s is bytes for base64 decoding
50-
s_bytes = s.encode("ascii") if isinstance(s, str) else s
48+
s_bytes = s.encode("ascii")
5149
# Add padding if needed
5250
padding = 4 - (len(s_bytes) % 4)
5351
if padding != 4:
5452
s_bytes += b"=" * padding
55-
return base64.urlsafe_b64decode(s_bytes)
56-
except Exception:
57-
# Fallback to raw bytes
58-
return s.encode("utf-8")
53+
raw_bytes = base64.urlsafe_b64decode(s_bytes)
54+
except Exception as e:
55+
raise ValueError("Invalid base64url certhash") from e
56+
57+
if len(raw_bytes) < 2:
58+
raise ValueError("Decoded certhash is too short to contain multihash header")
59+
60+
# Multihash format: <code><length><digest>
61+
code = raw_bytes[0]
62+
length = raw_bytes[1]
63+
digest = raw_bytes[2:]
64+
65+
if len(digest) != length:
66+
raise ValueError(
67+
f"Digest length mismatch: expected {length}, got {len(digest)}"
68+
)
69+
70+
return code, digest
5971

6072

6173
def certhash_decode(b: ByteString) -> str:
@@ -73,6 +85,6 @@ def certhash_decode(b: ByteString) -> str:
7385
"webrtc_decode",
7486
"webrtc_direct_encode",
7587
"webrtc_direct_decode",
76-
"certhash_encode",
77-
"certhash_decode",
88+
# "certhash_encode",
89+
# "certhash_decode",
7890
]

libp2p/transport/webrtc/private_to_private/initiate_connection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from ..async_bridge import TrioSafeWebRTCOperations
1919
from ..connection import WebRTCRawConnection
20-
from ...constants import (
20+
from ..constants import (
2121
DEFAULT_DIAL_TIMEOUT,
2222
SIGNALING_PROTOCOL,
2323
SDPHandshakeError,

libp2p/transport/webrtc/private_to_private/listener.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
)
1515
from libp2p.relay.circuit_v2.config import RelayConfig
1616

17-
from ...constants import (
17+
from ..constants import (
1818
DEFAULT_DIAL_TIMEOUT,
1919
DEFAULT_ICE_SERVERS,
2020
SIGNALING_PROTOCOL,

libp2p/transport/webrtc/private_to_private/signaling_stream_handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from libp2p.peer.id import ID
1616

1717
from ..connection import WebRTCRawConnection
18-
from ...constants import WebRTCError
18+
from ..constants import WebRTCError
1919
from .pb import Message
2020

2121
logger = logging.getLogger("webrtc.private.signaling_stream_handler")

libp2p/transport/webrtc/private_to_private/transport.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@
1717
from libp2p.host.basic_host import IHost
1818
from libp2p.transport.exceptions import OpenConnectionError
1919

20-
from ...constants import (
20+
from ..constants import (
2121
DEFAULT_DIAL_TIMEOUT,
2222
DEFAULT_ICE_SERVERS,
2323
SIGNALING_PROTOCOL,
2424
WebRTCError,
2525
)
26-
from ..util import (
26+
from ..private_to_public.util import (
2727
pick_random_ice_servers,
2828
)
2929
from .initiate_connection import initiate_connection
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import trio
2+
from aiortc import RTCDataChannel, RTCSessionDescription
3+
from .direct_rtc_connection import DirectPeerConnection
4+
from libp2p.transport.webrtc.private_to_public.util import (
5+
SDP,
6+
generate_noise_prologue,
7+
fingerprint_to_multiaddr,
8+
)
9+
from trio_asyncio import aio_as_trio
10+
from libp2p.transport.webrtc.noise_handshake import (
11+
generate_noise_prologue,
12+
NoiseEncrypter,
13+
)
14+
from libp2p.transport.webrtc.connection import WebRTCMultiaddrConnection
15+
from libp2p.transport.webrtc.muxer import DataChannelMuxerFactory
16+
from libp2p.transport.webrtc.constants import WEBRTC_CONNECTION_STATES
17+
import logging
18+
19+
logger = logging.getLogger("libp2p.transport.webrtc.private_to_public")
20+
21+
async def connect(
22+
peer_connection: DirectPeerConnection,
23+
ufrag: str,
24+
role: str
25+
):
26+
"""
27+
Establish a WebRTC-Direct connection, perform the noise handshake, and return the upgraded connection.
28+
"""
29+
30+
# Create data channel for noise handshake (negotiated, id=0)
31+
handshake_channel: RTCDataChannel = peer_connection.peer_connection.createDataChannel(
32+
"", negotiated=True, id=0
33+
)
34+
35+
try:
36+
if role == "client":
37+
logger.debug("client creating local offer")
38+
offer = await peer_connection.createOffer()
39+
logger.debug("client created local offer %s", offer.sdp)
40+
munged_offer = SDP.munge_offer(offer, ufrag)
41+
logger.debug("client setting local offer %s", munged_offer.sdp)
42+
await aio_as_trio(peer_connection.setLocalDescription(munged_offer))
43+
44+
answer_sdp = SDP.server_answer_from_multiaddr(remote_addr, ufrag)
45+
logger.debug("client setting server description %s", answer_sdp.sdp)
46+
await aio_as_trio(peer_connection.setRemoteDescription(answer_sdp))
47+
else:
48+
offer_sdp = SDP.client_offer_from_multiaddr(remote_addr, ufrag)
49+
logger.debug("server setting client %s %s", offer_sdp.type, offer_sdp.sdp)
50+
await aio_as_trio(peer_connection.setRemoteDescription(offer_sdp))
51+
52+
logger.debug("server creating local answer")
53+
answer = await peer_connection.createAnswer()
54+
logger.debug("server created local answer")
55+
munged_answer = SDP.munge_offer(answer, ufrag)
56+
logger.debug("server setting local description %s", munged_answer.sdp)
57+
await aio_as_trio(peer_connection.setLocalDescription(munged_answer))
58+
59+
# TODO: Fix this
60+
# Wait for handshake channel to open
61+
if handshake_channel.readyState != "open":
62+
logger.debug(
63+
"%s wait for handshake channel to open, starting status %s",
64+
role,
65+
handshake_channel.readyState,
66+
)
67+
# Wait for the 'open' event or signal cancellation
68+
open_event = trio.Event()
69+
70+
def on_open():
71+
open_event.set()
72+
73+
handshake_channel.on("open", on_open)
74+
with trio.move_on_after(30): # 30s timeout
75+
await open_event.wait()
76+
if handshake_channel.readyState != "open":
77+
raise Exception("Handshake data channel did not open in time")
78+
79+
logger.debug("%s handshake channel opened", role)
80+
81+
if role == "server":
82+
remote_fingerprint = peer_connection.remoteFingerprint().value
83+
remote_addr = fingerprint_to_multiaddr(remote_fingerprint)
84+
85+
# Get local fingerprint
86+
local_desc = peer_connection.localDescription
87+
local_fingerprint = SDP.get_fingerprint_from_sdp(local_desc.sdp)
88+
if local_fingerprint is None:
89+
raise Exception("Could not get fingerprint from local description sdp")
90+
91+
logger.debug("%s performing noise handshake", role)
92+
#TODO: Complete the noise handshake and connection authentication
93+
noiseProlouge = generate_noise_prologue(local_fingerprint, remote_addr, role)
94+
95+
except Exception as e:
96+
logger.error("%s noise handshake failed: %s", role, e)
97+
raise
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
from aiortc import (RTCConfiguration, RTCPeerConnection, RTCSessionDescription, RTCDtlsFingerprint)
2+
from trio_asyncio import aio_as_trio
3+
from dataclasses import dataclass
4+
from .gen_certificate import WebRTCCertificate
5+
import datetime
6+
from ..constants import MAX_MESSAGE_SIZE
7+
8+
@dataclass
9+
class DirectRTCConfiguration:
10+
ufrag: str
11+
peer_connection: RTCPeerConnection
12+
rtc_config: RTCConfiguration
13+
14+
class DirectPeerConnection(RTCPeerConnection):
15+
def __init__(self, direct_config: DirectRTCConfiguration):
16+
self.ufrag = direct_config.ufrag
17+
self.peer_connection = direct_config.peer_connection
18+
super().__init__(direct_config.rtc_config)
19+
20+
async def createOffer(self) -> RTCSessionDescription:
21+
"""
22+
Create SDP offer, patching ICE ufrag and pwd to self.ufrag and self.upwd,
23+
set as local description, and return the patched RTCSessionDescription.
24+
"""
25+
offer = await aio_as_trio(super().createOffer())
26+
27+
sdp_lines = offer.sdp.splitlines()
28+
new_lines = []
29+
for line in sdp_lines:
30+
if line.startswith("a=ice-ufrag:"):
31+
new_lines.append(f"a=ice-ufrag:{getattr(self, 'ufrag', self.ufrag)}")
32+
elif line.startswith("a=ice-pwd:"):
33+
new_lines.append(f"a=ice-pwd:{getattr(self, 'ufrag', self.ufrag)}")
34+
else:
35+
new_lines.append(line)
36+
patched_sdp = "\r\n".join(new_lines) + "\r\n"
37+
38+
patched_offer = RTCSessionDescription(sdp=patched_sdp, type=offer.type)
39+
await aio_as_trio(self.setLocalDescription(patched_offer))
40+
return patched_offer
41+
42+
async def createAnswer(self) -> RTCSessionDescription:
43+
"""
44+
Create SDP answer, patching ICE ufrag and pwd to self.ufrag and self.upwd,
45+
set as local description, and return the patched RTCSessionDescription.
46+
"""
47+
answer = await aio_as_trio(super().createAnswer())
48+
49+
sdp_lines = answer.sdp.splitlines()
50+
new_lines = []
51+
for line in sdp_lines:
52+
if line.startswith("a=ice-ufrag:"):
53+
new_lines.append(f"a=ice-ufrag:{getattr(self, 'ufrag', self.ufrag)}")
54+
elif line.startswith("a=ice-pwd:"):
55+
new_lines.append(f"a=ice-pwd:{getattr(self, 'ufrag', self.ufrag)}")
56+
else:
57+
new_lines.append(line)
58+
patched_sdp = "\r\n".join(new_lines) + "\r\n"
59+
60+
patched_answer = RTCSessionDescription(sdp=patched_sdp, type=answer.type)
61+
await aio_as_trio(self.setLocalDescription(patched_answer))
62+
return patched_answer
63+
64+
65+
def remoteFingerprint(self) -> RTCDtlsFingerprint:
66+
pass
67+
# return self.peer_connection.
68+
69+
@staticmethod
70+
async def create_dialer_rtc_peer_connection(
71+
role: str,
72+
ufrag: str,
73+
rtc_configuration: RTCConfiguration,
74+
certificate: WebRTCCertificate | None = None,
75+
):
76+
"""
77+
Create a DirectRTCPeerConnection for dialing, similar to the JS createDialerRTCPeerConnection.
78+
"""
79+
80+
if certificate is None:
81+
certificate = WebRTCCertificate.generate()
82+
83+
# TODO: ICE servers. Should we use the ones from the rtc_configuration?
84+
85+
# # ICE servers
86+
# ice_servers = rtc_config.get("iceServers") if isinstance(rtc_config, dict) else getattr(rtc_config, "iceServers", None)
87+
# if ice_servers is None and default_ice_servers is not None:
88+
# ice_servers = default_ice_servers
89+
90+
# if map_ice_servers is not None:
91+
# mapped_ice_servers = map_ice_servers(ice_servers)
92+
# else:
93+
# mapped_ice_servers = ice_servers
94+
95+
peer_connection = RTCPeerConnection(
96+
RTCConfiguration(
97+
f"{role}-{(datetime.datetime.now(datetime.timezone.utc).timestamp() * 1000)}",
98+
disable_fingerprint_verification=True,
99+
disable_auto_negotiation=True,
100+
certificate_pem_file=certificate.to_pem()[0],
101+
key_pem_file=certificate.to_pem()[1],
102+
enable_ice_udp_mux=(role == "server"),
103+
max_message_size=MAX_MESSAGE_SIZE,
104+
# ice_servers=mapped_ice_servers,
105+
)
106+
)
107+
return DirectPeerConnection(DirectRTCConfiguration(ufrag, peer_connection, rtc_configuration))

0 commit comments

Comments
 (0)