14
14
from libp2p .abc import INetStream , IRawConnection
15
15
from libp2p .peer .id import ID
16
16
17
+ from .pb import Message
17
18
from ..async_bridge import TrioSafeWebRTCOperations
18
19
from ..connection import WebRTCRawConnection
19
20
from ..constants import (
23
24
WebRTCError ,
24
25
)
25
26
27
+ from libp2p .abc import (
28
+ IHost
29
+ )
30
+
26
31
logger = logging .getLogger ("webrtc.private.initiate_connection" )
27
32
28
33
29
34
async def initiate_connection (
30
35
maddr : Multiaddr ,
31
36
rtc_config : RTCConfiguration ,
32
- host : Any ,
37
+ host : IHost ,
33
38
timeout : float = DEFAULT_DIAL_TIMEOUT ,
34
39
) -> IRawConnection :
35
40
"""
@@ -80,30 +85,16 @@ async def initiate_connection(
80
85
81
86
logger .info ("Created RTCPeerConnection and data channel" )
82
87
83
- # Setup ICE candidate collection
84
- ice_candidates : list [Any ] = []
85
- ice_gathering_complete = trio .Event ()
86
-
87
- def on_ice_candidate (candidate : Any ) -> None :
88
- if candidate :
89
- ice_candidates .append (candidate )
90
- logger .debug (f"Generated ICE candidate: { candidate .candidate } " )
91
- else :
92
- # End of candidates signaled
93
- ice_gathering_complete .set ()
94
- logger .debug ("ICE gathering complete" )
95
-
96
- # Register ICE candidate handler
97
- peer_connection .on ("icecandidate" , on_ice_candidate )
98
-
99
88
# Setup data channel ready event
100
89
data_channel_ready = trio .Event ()
101
-
102
- def on_data_channel_open () -> None :
90
+
91
+ @data_channel .on ("open" )
92
+ def on_data_channel_open ():
103
93
logger .info ("Data channel opened" )
104
94
data_channel_ready .set ()
105
95
106
- def on_data_channel_error (error : Any ) -> None :
96
+ @data_channel .on ("error" )
97
+ def on_data_channel_error (error : Any ):
107
98
logger .error (f"Data channel error: { error } " )
108
99
109
100
# Register data channel event handlers
@@ -118,36 +109,44 @@ def on_data_channel_error(error: Any) -> None:
118
109
119
110
# Wait for ICE gathering to complete
120
111
with trio .move_on_after (timeout ):
121
- await ice_gathering_complete .wait ()
112
+ while peer_connection .iceGatheringState != "complete" :
113
+ await trio .sleep (0.05 )
122
114
115
+ logger .debug ("Sending SDP_offer to peer as initiator" )
123
116
# Send offer with all ICE candidates
124
- offer_msg = {"type" : "offer" , "sdp" : offer .sdp , "sdpType" : "offer" }
117
+ offer_msg = Message ()
118
+ offer_msg .type = Message .SDP_OFFER
119
+ offer_msg .data = offer .sdp
125
120
await _send_signaling_message (signaling_stream , offer_msg )
126
121
127
- # Send ICE candidates
128
- for candidate in ice_candidates :
129
- candidate_msg = {
130
- "type" : "ice-candidate" ,
131
- "candidate" : candidate .candidate ,
132
- "sdpMid" : candidate .sdpMid ,
133
- "sdpMLineIndex" : candidate .sdpMLineIndex ,
134
- }
135
- await _send_signaling_message (signaling_stream , candidate_msg )
136
-
137
- # Signal end of candidates
138
- end_msg = {"type" : "ice-candidate-end" }
139
- await _send_signaling_message (signaling_stream , end_msg )
140
-
122
+ # (Note: aiortc does not emit ice candidate event, per candidate (like js)
123
+ # but sends it along SDP. To maintain interop, we extract adn resend in given format )
124
+
125
+ # get SDP from local_descriptor for extracting ice_candidates
126
+ sdp = peer_connection .localDescription .sdp
127
+ # extract ice_candidates from sdp
128
+ ice_candidate = Message ()
129
+ ice_candidate .type = Message .ICE_CANDIDATE
130
+ for line in sdp .splitlines ():
131
+ if line .startswith ("a=candidate:" ):
132
+ candidate_line = line [len ("a=" ):]
133
+ # Need to make sure the candidate_line matches with expected cadidate in js
134
+ logger .Debug ("Candidate sent to peer: " , candidate_line )
135
+ ice_candidate .data = candidate_line
136
+ await _send_signaling_message (signaling_stream , ice_candidate )
137
+
138
+ ice_candidate .data = "" # Empty string signals
139
+ await _send_signaling_message (signaling_stream , ice_candidate )
141
140
logger .info ("Sent offer and ICE candidates" )
142
141
143
142
# Wait for answer
144
143
answer_msg = await _receive_signaling_message (signaling_stream , timeout )
145
- if answer_msg .get ( " type" ) != "answer" :
146
- raise SDPHandshakeError (f"Expected answer, got: { answer_msg .get ( ' type' ) } " )
144
+ if answer_msg .type != Message . SDP_ANSWER :
145
+ raise SDPHandshakeError (f"Expected answer, got: { answer_msg .type } " )
147
146
148
147
# Set remote description
149
148
answer = RTCSessionDescription (
150
- sdp = answer_msg [ "sdp" ] , type = answer_msg [ "sdpType" ]
149
+ sdp = answer_msg . data , type = 'answer'
151
150
)
152
151
bridge = TrioSafeWebRTCOperations ._get_bridge ()
153
152
async with bridge :
@@ -162,7 +161,6 @@ def on_data_channel_error(error: Any) -> None:
162
161
163
162
# Wait for data channel to be ready
164
163
connection_failed = trio .Event ()
165
-
166
164
def on_connection_state_change () -> None :
167
165
state = peer_connection .connectionState
168
166
logger .debug (f"Connection state: { state } " )
@@ -198,6 +196,9 @@ def on_connection_state_change() -> None:
198
196
is_initiator = True ,
199
197
)
200
198
199
+ logger .debug ('initiator connected, closing init channel' )
200
+ data_channel .close ()
201
+
201
202
logger .info (f"Successfully established WebRTC connection to { target_peer_id } " )
202
203
return connection
203
204
@@ -220,39 +221,29 @@ def on_connection_state_change() -> None:
220
221
raise WebRTCError (f"Connection initiation failed: { e } " ) from e
221
222
222
223
223
- async def _send_signaling_message (stream : INetStream , message : dict [ str , Any ] ) -> None :
224
+ async def _send_signaling_message (stream : INetStream , message : Message ) -> None :
224
225
"""Send a signaling message over the stream"""
225
226
try :
226
- message_data = json .dumps (message ).encode ("utf-8" )
227
- message_length = len (message_data ).to_bytes (4 , byteorder = "big" )
228
- await stream .write (message_length + message_data )
229
- logger .debug (f"Sent signaling message: { message ['type' ]} " )
227
+ # message_length = len(message_data).to_bytes(4, byteorder="big")
228
+ await stream .write (message .SerializeToString ())
229
+ logger .debug (f"Sent signaling message: { message .type } " )
230
230
except Exception as e :
231
231
logger .error (f"Failed to send signaling message: { e } " )
232
232
raise
233
233
234
234
235
235
async def _receive_signaling_message (
236
236
stream : INetStream , timeout : float
237
- ) -> dict [ str , Any ] :
237
+ ) -> Message :
238
238
"""Receive a signaling message from the stream"""
239
239
try :
240
240
with trio .move_on_after (timeout ) as cancel_scope :
241
- # Read message length
242
- length_data = await stream .read (4 )
243
- if len (length_data ) != 4 :
244
- raise WebRTCError ("Failed to read message length" )
245
-
246
- message_length = int .from_bytes (length_data , byteorder = "big" )
247
-
248
241
# Read message data
249
- message_data = await stream .read (message_length )
250
- if len (message_data ) != message_length :
251
- raise WebRTCError ("Failed to read complete message" )
252
-
253
- message = json .loads (message_data .decode ("utf-8" ))
254
- logger .debug (f"Received signaling message: { message .get ('type' )} " )
255
- return message
242
+ message_data = await stream .read ()
243
+ deserealized_msg = Message ()
244
+ deserealized_msg .ParseFromString (message_data )
245
+ logger .debug (f"Received signaling message: { deserealized_msg .type } " )
246
+ return deserealized_msg
256
247
257
248
if cancel_scope .cancelled_caught :
258
249
raise WebRTCError ("Signaling message receive timeout" )
0 commit comments