Skip to content

Commit 5d03f4f

Browse files
committed
Fix SDP and ICE message transfer
1 parent e7c1910 commit 5d03f4f

File tree

6 files changed

+170
-67
lines changed

6 files changed

+170
-67
lines changed

libp2p/transport/webrtc/private_to_private/initiate_connection.py

Lines changed: 51 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from libp2p.abc import INetStream, IRawConnection
1515
from libp2p.peer.id import ID
1616

17+
from .pb import Message
1718
from ..async_bridge import TrioSafeWebRTCOperations
1819
from ..connection import WebRTCRawConnection
1920
from ..constants import (
@@ -23,13 +24,17 @@
2324
WebRTCError,
2425
)
2526

27+
from libp2p.abc import (
28+
IHost
29+
)
30+
2631
logger = logging.getLogger("webrtc.private.initiate_connection")
2732

2833

2934
async def initiate_connection(
3035
maddr: Multiaddr,
3136
rtc_config: RTCConfiguration,
32-
host: Any,
37+
host: IHost,
3338
timeout: float = DEFAULT_DIAL_TIMEOUT,
3439
) -> IRawConnection:
3540
"""
@@ -80,30 +85,16 @@ async def initiate_connection(
8085

8186
logger.info("Created RTCPeerConnection and data channel")
8287

83-
# Setup ICE candidate collection
84-
ice_candidates: list[Any] = []
85-
ice_gathering_complete = trio.Event()
86-
87-
def on_ice_candidate(candidate: Any) -> None:
88-
if candidate:
89-
ice_candidates.append(candidate)
90-
logger.debug(f"Generated ICE candidate: {candidate.candidate}")
91-
else:
92-
# End of candidates signaled
93-
ice_gathering_complete.set()
94-
logger.debug("ICE gathering complete")
95-
96-
# Register ICE candidate handler
97-
peer_connection.on("icecandidate", on_ice_candidate)
98-
9988
# Setup data channel ready event
10089
data_channel_ready = trio.Event()
101-
102-
def on_data_channel_open() -> None:
90+
91+
@data_channel.on("open")
92+
def on_data_channel_open():
10393
logger.info("Data channel opened")
10494
data_channel_ready.set()
10595

106-
def on_data_channel_error(error: Any) -> None:
96+
@data_channel.on("error")
97+
def on_data_channel_error(error: Any):
10798
logger.error(f"Data channel error: {error}")
10899

109100
# Register data channel event handlers
@@ -118,36 +109,44 @@ def on_data_channel_error(error: Any) -> None:
118109

119110
# Wait for ICE gathering to complete
120111
with trio.move_on_after(timeout):
121-
await ice_gathering_complete.wait()
112+
while peer_connection.iceGatheringState != "complete":
113+
await trio.sleep(0.05)
122114

115+
logger.debug("Sending SDP_offer to peer as initiator")
123116
# Send offer with all ICE candidates
124-
offer_msg = {"type": "offer", "sdp": offer.sdp, "sdpType": "offer"}
117+
offer_msg = Message()
118+
offer_msg.type = Message.SDP_OFFER
119+
offer_msg.data = offer.sdp
125120
await _send_signaling_message(signaling_stream, offer_msg)
126121

127-
# Send ICE candidates
128-
for candidate in ice_candidates:
129-
candidate_msg = {
130-
"type": "ice-candidate",
131-
"candidate": candidate.candidate,
132-
"sdpMid": candidate.sdpMid,
133-
"sdpMLineIndex": candidate.sdpMLineIndex,
134-
}
135-
await _send_signaling_message(signaling_stream, candidate_msg)
136-
137-
# Signal end of candidates
138-
end_msg = {"type": "ice-candidate-end"}
139-
await _send_signaling_message(signaling_stream, end_msg)
140-
122+
# (Note: aiortc does not emit ice candidate event, per candidate (like js)
123+
# but sends it along SDP. To maintain interop, we extract adn resend in given format )
124+
125+
# get SDP from local_descriptor for extracting ice_candidates
126+
sdp = peer_connection.localDescription.sdp
127+
# extract ice_candidates from sdp
128+
ice_candidate =Message()
129+
ice_candidate.type = Message.ICE_CANDIDATE
130+
for line in sdp.splitlines():
131+
if line.startswith("a=candidate:"):
132+
candidate_line = line[len("a="):]
133+
# Need to make sure the candidate_line matches with expected cadidate in js
134+
logger.Debug("Candidate sent to peer: ", candidate_line)
135+
ice_candidate.data = candidate_line
136+
await _send_signaling_message(signaling_stream, ice_candidate)
137+
138+
ice_candidate.data = "" # Empty string signals
139+
await _send_signaling_message(signaling_stream, ice_candidate)
141140
logger.info("Sent offer and ICE candidates")
142141

143142
# Wait for answer
144143
answer_msg = await _receive_signaling_message(signaling_stream, timeout)
145-
if answer_msg.get("type") != "answer":
146-
raise SDPHandshakeError(f"Expected answer, got: {answer_msg.get('type')}")
144+
if answer_msg.type != Message.SDP_ANSWER:
145+
raise SDPHandshakeError(f"Expected answer, got: {answer_msg.type}")
147146

148147
# Set remote description
149148
answer = RTCSessionDescription(
150-
sdp=answer_msg["sdp"], type=answer_msg["sdpType"]
149+
sdp=answer_msg.data, type='answer'
151150
)
152151
bridge = TrioSafeWebRTCOperations._get_bridge()
153152
async with bridge:
@@ -162,7 +161,6 @@ def on_data_channel_error(error: Any) -> None:
162161

163162
# Wait for data channel to be ready
164163
connection_failed = trio.Event()
165-
166164
def on_connection_state_change() -> None:
167165
state = peer_connection.connectionState
168166
logger.debug(f"Connection state: {state}")
@@ -198,6 +196,9 @@ def on_connection_state_change() -> None:
198196
is_initiator=True,
199197
)
200198

199+
logger.debug('initiator connected, closing init channel')
200+
data_channel.close()
201+
201202
logger.info(f"Successfully established WebRTC connection to {target_peer_id}")
202203
return connection
203204

@@ -220,39 +221,29 @@ def on_connection_state_change() -> None:
220221
raise WebRTCError(f"Connection initiation failed: {e}") from e
221222

222223

223-
async def _send_signaling_message(stream: INetStream, message: dict[str, Any]) -> None:
224+
async def _send_signaling_message(stream: INetStream, message: Message) -> None:
224225
"""Send a signaling message over the stream"""
225226
try:
226-
message_data = json.dumps(message).encode("utf-8")
227-
message_length = len(message_data).to_bytes(4, byteorder="big")
228-
await stream.write(message_length + message_data)
229-
logger.debug(f"Sent signaling message: {message['type']}")
227+
# message_length = len(message_data).to_bytes(4, byteorder="big")
228+
await stream.write(message.SerializeToString())
229+
logger.debug(f"Sent signaling message: {message.type}")
230230
except Exception as e:
231231
logger.error(f"Failed to send signaling message: {e}")
232232
raise
233233

234234

235235
async def _receive_signaling_message(
236236
stream: INetStream, timeout: float
237-
) -> dict[str, Any]:
237+
) -> Message:
238238
"""Receive a signaling message from the stream"""
239239
try:
240240
with trio.move_on_after(timeout) as cancel_scope:
241-
# Read message length
242-
length_data = await stream.read(4)
243-
if len(length_data) != 4:
244-
raise WebRTCError("Failed to read message length")
245-
246-
message_length = int.from_bytes(length_data, byteorder="big")
247-
248241
# Read message data
249-
message_data = await stream.read(message_length)
250-
if len(message_data) != message_length:
251-
raise WebRTCError("Failed to read complete message")
252-
253-
message = json.loads(message_data.decode("utf-8"))
254-
logger.debug(f"Received signaling message: {message.get('type')}")
255-
return message
242+
message_data = await stream.read()
243+
deserealized_msg = Message()
244+
deserealized_msg.ParseFromString(message_data)
245+
logger.debug(f"Received signaling message: {deserealized_msg.type}")
246+
return deserealized_msg
256247

257248
if cancel_scope.cancelled_caught:
258249
raise WebRTCError("Signaling message receive timeout")
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
"""
2+
Protocol buffer package for webrtc_private_to_private.
3+
4+
Contains generated protobuf code for webrtc_private_to_private protocol.
5+
"""
6+
7+
# Import the classes to be accessible directly from the package
8+
from .message_pb2 import (
9+
Message,
10+
)
11+
12+
__all__ = ["Message"]

libp2p/transport/webrtc/private_to_private/pb/message_pb2.py

Lines changed: 38 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
"""
2+
@generated by mypy-protobuf. Do not edit manually!
3+
isort:skip_file
4+
"""
5+
import builtins
6+
import google.protobuf.descriptor
7+
import google.protobuf.internal.enum_type_wrapper
8+
import google.protobuf.message
9+
import sys
10+
import typing
11+
12+
if sys.version_info >= (3, 10):
13+
import typing as typing_extensions
14+
else:
15+
import typing_extensions
16+
17+
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
18+
19+
class Message(google.protobuf.message.Message):
20+
DESCRIPTOR: google.protobuf.descriptor.Descriptor
21+
22+
class _Type:
23+
ValueType = typing.NewType("ValueType", builtins.int)
24+
V: typing_extensions.TypeAlias = ValueType
25+
26+
class _TypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[Message._Type.ValueType], builtins.type): # noqa: F821
27+
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
28+
SDP_OFFER: Message._Type.ValueType # 0
29+
"""String of `RTCSessionDescription.sdp`"""
30+
SDP_ANSWER: Message._Type.ValueType # 1
31+
"""String of `RTCSessionDescription.sdp`"""
32+
ICE_CANDIDATE: Message._Type.ValueType # 2
33+
"""String of `RTCIceCandidate.toJSON()`"""
34+
35+
class Type(_Type, metaclass=_TypeEnumTypeWrapper):
36+
"""Specifies type in `data` field."""
37+
38+
SDP_OFFER: Message.Type.ValueType # 0
39+
"""String of `RTCSessionDescription.sdp`"""
40+
SDP_ANSWER: Message.Type.ValueType # 1
41+
"""String of `RTCSessionDescription.sdp`"""
42+
ICE_CANDIDATE: Message.Type.ValueType # 2
43+
"""String of `RTCIceCandidate.toJSON()`"""
44+
45+
TYPE_FIELD_NUMBER: builtins.int
46+
DATA_FIELD_NUMBER: builtins.int
47+
type: global___Message.Type.ValueType
48+
data: builtins.str
49+
def __init__(
50+
self,
51+
*,
52+
type: global___Message.Type.ValueType | None = ...,
53+
data: builtins.str | None = ...,
54+
) -> None: ...
55+
def HasField(self, field_name: typing_extensions.Literal["_data", b"_data", "_type", b"_type", "data", b"data", "type", b"type"]) -> builtins.bool: ...
56+
def ClearField(self, field_name: typing_extensions.Literal["_data", b"_data", "_type", b"_type", "data", b"data", "type", b"type"]) -> None: ...
57+
@typing.overload
58+
def WhichOneof(self, oneof_group: typing_extensions.Literal["_data", b"_data"]) -> typing_extensions.Literal["data"] | None: ...
59+
@typing.overload
60+
def WhichOneof(self, oneof_group: typing_extensions.Literal["_type", b"_type"]) -> typing_extensions.Literal["type"] | None: ...
61+
62+
global___Message = Message

libp2p/transport/webrtc/private_to_private/signaling_stream_handler.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import json
21
import logging
32
from typing import Any
43

@@ -14,7 +13,7 @@
1413
from libp2p.abc import INetStream, IRawConnection
1514
from libp2p.crypto.ed25519 import create_new_key_pair
1615
from libp2p.peer.id import ID
17-
16+
from .pb import Message
1817
from ..connection import WebRTCRawConnection
1918
from ..constants import WebRTCError
2019

@@ -71,8 +70,8 @@ def on_channel_open() -> None:
7170
offer_data = await stream.read()
7271
if not offer_data:
7372
raise WebRTCError("No offer data received")
74-
75-
offer_message = json.loads(offer_data.decode("utf-8"))
73+
offer_message = Message()
74+
offer_message.ParseFromString(offer_data)
7675
if offer_message.get("type") != "offer":
7776
raise WebRTCError(f"Expected offer, got: {offer_message.get('type')}")
7877

@@ -96,9 +95,10 @@ def on_channel_open() -> None:
9695

9796
# Send answer back
9897
try:
99-
answer_message = {"type": answer.type, "sdp": answer.sdp}
100-
answer_data = json.dumps(answer_message).encode("utf-8")
101-
await stream.write(answer_data)
98+
answer_message = Message()
99+
answer_message.type = Message.SDP_ANSWER
100+
answer_message.data = answer.sdp
101+
await stream.write(answer_message)
102102
logger.info("Sent SDP answer")
103103

104104
except Exception as e:

0 commit comments

Comments
 (0)