From 3ee28abd828857c380b04714647edd7ed6515f82 Mon Sep 17 00:00:00 2001 From: Neha Kumari Date: Fri, 4 Apr 2025 02:00:38 +0530 Subject: [PATCH 1/9] initial-setup --- libp2p/transport/webrtc/__init__.py | 0 libp2p/transport/webrtc/signal.py | 32 ++++++++ libp2p/transport/webrtc/webrtc.py | 119 ++++++++++++++++++++++++++++ 3 files changed, 151 insertions(+) create mode 100644 libp2p/transport/webrtc/__init__.py create mode 100644 libp2p/transport/webrtc/signal.py create mode 100644 libp2p/transport/webrtc/webrtc.py diff --git a/libp2p/transport/webrtc/__init__.py b/libp2p/transport/webrtc/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/libp2p/transport/webrtc/signal.py b/libp2p/transport/webrtc/signal.py new file mode 100644 index 000000000..d9db9d9f5 --- /dev/null +++ b/libp2p/transport/webrtc/signal.py @@ -0,0 +1,32 @@ +import asyncio +import websockets +import json + +connected_peers = {} + +async def signaling_handler(websocket, path): + async for message in websocket: + data = json.loads(message) + + if data["type"] == "register": + peer_id = data["peer_id"] + connected_peers[peer_id] = websocket + print(f"Peer {peer_id} registered.") + + elif data["type"] == "offer": + target = data["target"] + if target in connected_peers: + await connected_peers[target].send(json.dumps(data)) + + elif data["type"] == "answer": + target = data["target"] + if target in connected_peers: + await connected_peers[target].send(json.dumps(data)) + +async def start_signaling_server(): + server = await websockets.serve(signaling_handler, "localhost", 8765) + print("Signaling server started on ws://localhost:8765") + await server.wait_closed() + +if __name__ == "__main__": + asyncio.run(start_signaling_server()) diff --git a/libp2p/transport/webrtc/webrtc.py b/libp2p/transport/webrtc/webrtc.py new file mode 100644 index 000000000..2ef7ec10a --- /dev/null +++ b/libp2p/transport/webrtc/webrtc.py @@ -0,0 +1,119 @@ +import asyncio +import json +import websockets +import logging +from aiortc import RTCIceCandidate, RTCPeerConnection, RTCSessionDescription, RTCDataChannel +from libp2p.pubsub.gossipsub import GossipSub +from libp2p.pubsub.pubsub import Pubsub +from libp2p.host.basic_host import BasicHost +from multiaddr import Multiaddr + +SIGNALING_SERVER_URL = "ws://localhost:8765" + +# Initialize logger +logger = logging.getLogger("webrtc-transport") +logging.basicConfig(level=logging.INFO) + +class WebRTCTransport: + def __init__(self, peer_id, host: BasicHost): + self.peer_id = peer_id + self.host = host + self.peer_connection = RTCPeerConnection() + self.data_channel = None + self.pubsub = None + self.websocket = None + + async def connect_signaling_server(self): + """Connects to the WebSocket-based signaling server""" + try: + self.websocket = await websockets.connect(SIGNALING_SERVER_URL) + await self.websocket.send(json.dumps({"type": "register", "peer_id": self.peer_id})) + asyncio.create_task(self.listen_signaling()) + except Exception as e: + logger.error(f"Failed to connect to signaling server: {e}") + + async def listen_signaling(self): + """Listens for incoming SDP offers and answers""" + try: + async for message in self.websocket: + data = json.loads(message) + if data["type"] == "offer": + await self.handle_offer(data) + elif data["type"] == "answer": + await self.handle_answer(data) + except Exception as e: + logger.error(f"Error in signaling listener: {e}") + + async def handle_offer(self, data): + """Handles incoming SDP offers""" + offer = RTCSessionDescription(sdp=data["sdp"], type=data["type"]) + await self.peer_connection.setRemoteDescription(offer) + + # Create an answer + answer = await self.peer_connection.createAnswer() + await self.peer_connection.setLocalDescription(answer) + + await self.websocket.send(json.dumps({ + "type": "answer", + "target": data["peer_id"], + "sdp": answer.sdp + })) + + async def handle_answer(self, data): + """Handles incoming SDP answers""" + answer = RTCSessionDescription(sdp=data["sdp"], type=data["type"]) + await self.peer_connection.setRemoteDescription(answer) + + async def create_data_channel(self): + """Creates and opens a WebRTC data channel""" + self.data_channel = self.peer_connection.createDataChannel("libp2p-webrtc") + + @self.data_channel.on("open") + def on_open(): + logger.info(f"Data channel open with peer {self.peer_id}") + + @self.data_channel.on("message") + def on_message(message): + logger.info(f"Received message from peer {self.peer_id}: {message}") + + async def initiate_connection(self, target_peer_id): + """Initiates connection with a peer""" + self.data_channel = self.peer_connection.createDataChannel("libp2p-webrtc") + + offer = await self.peer_connection.createOffer() + await self.peer_connection.setLocalDescription(offer) + + await self.websocket.send(json.dumps({ + "type": "offer", + "peer_id": self.peer_id, + "target": target_peer_id, + "sdp": offer.sdp + })) + + async def start_peer_discovery(self): + """Starts peer discovery using GossipSub""" + gossipsub = GossipSub() + self.pubsub = Pubsub(self.host, gossipsub, None) + + topic = await self.pubsub.subscribe("webrtc-peer-discovery") + + async def handle_message(): + async for msg in topic: + logger.info(f"Discovered Peer: {msg.data.decode()}") + + asyncio.create_task(handle_message()) + + # Advertise this peer + await self.pubsub.publish("webrtc-peer-discovery", self.peer_id.encode()) + +# Multiaddr Parsing & Validation +def parse_webrtc_multiaddr(multiaddr_str): + """Parse and validate a WebRTC multiaddr.""" + try: + addr = Multiaddr(multiaddr_str) + if "/webrtc" not in [p.name for p in addr.protocols()]: + raise ValueError("Invalid WebRTC multiaddr: Missing /webrtc protocol") + return addr + except Exception as e: + logger.error(f"Failed to parse multiaddr: {e}") + return None From 1ea3a049d8d6bd1bc2734e1a1d19b631a5080717 Mon Sep 17 00:00:00 2001 From: Neha Kumari Date: Sun, 13 Apr 2025 22:32:34 +0530 Subject: [PATCH 2/9] improve-webrtc-and-integrate-loopback_test --- libp2p/transport/webrtc/signal.py | 32 --- libp2p/transport/webrtc/test_loopback.py | 72 ++++++ libp2p/transport/webrtc/webrtc.py | 209 ++++++++++++------ .../webrtc/webrtc_signal_protocol.py | 33 +++ 4 files changed, 249 insertions(+), 97 deletions(-) delete mode 100644 libp2p/transport/webrtc/signal.py create mode 100644 libp2p/transport/webrtc/test_loopback.py create mode 100644 libp2p/transport/webrtc/webrtc_signal_protocol.py diff --git a/libp2p/transport/webrtc/signal.py b/libp2p/transport/webrtc/signal.py deleted file mode 100644 index d9db9d9f5..000000000 --- a/libp2p/transport/webrtc/signal.py +++ /dev/null @@ -1,32 +0,0 @@ -import asyncio -import websockets -import json - -connected_peers = {} - -async def signaling_handler(websocket, path): - async for message in websocket: - data = json.loads(message) - - if data["type"] == "register": - peer_id = data["peer_id"] - connected_peers[peer_id] = websocket - print(f"Peer {peer_id} registered.") - - elif data["type"] == "offer": - target = data["target"] - if target in connected_peers: - await connected_peers[target].send(json.dumps(data)) - - elif data["type"] == "answer": - target = data["target"] - if target in connected_peers: - await connected_peers[target].send(json.dumps(data)) - -async def start_signaling_server(): - server = await websockets.serve(signaling_handler, "localhost", 8765) - print("Signaling server started on ws://localhost:8765") - await server.wait_closed() - -if __name__ == "__main__": - asyncio.run(start_signaling_server()) diff --git a/libp2p/transport/webrtc/test_loopback.py b/libp2p/transport/webrtc/test_loopback.py new file mode 100644 index 000000000..0ecdd2d9a --- /dev/null +++ b/libp2p/transport/webrtc/test_loopback.py @@ -0,0 +1,72 @@ +import logging +import trio +from multiaddr import Multiaddr +from libp2p.peer.id import ID +from libp2p.security.noise.transport import Transport as NoiseTransport +from libp2p.transport.tcp.tcp import TCP +from libp2p.transport.upgrader import TransportUpgrader +from libp2p.stream_muxer.mplex.mplex import Mplex +from libp2p.host.basic_host import BasicHost +from libp2p.network.swarm import Swarm +from libp2p.crypto.ed25519 import create_new_key_pair +from libp2p.security.noise.transport import PROTOCOL_ID +from libp2p.stream_muxer.mplex.mplex import MPLEX_PROTOCOL_ID +from libp2p.peer.peerstore import PeerStore +from .webrtc import WebRTCTransport + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("webrtc-loopback-test") + +async def build_host(name): + key_pair = create_new_key_pair() + peer_id = ID.from_pubkey(key_pair.public_key) + logger.info(f"Peer {name} ID: {peer_id}") + + base_transport = TCP() + peer_store= PeerStore() + secure_transports = { + PROTOCOL_ID: NoiseTransport(libp2p_keypair=key_pair), + } + muxer_transports = { + MPLEX_PROTOCOL_ID: Mplex + } + + upgrader = TransportUpgrader(secure_transports, muxer_transports) + swarm = Swarm(peer_id, peer_store, upgrader, base_transport) + host = BasicHost(swarm) + + await host.get_network().listen(Multiaddr(f"/ip4/127.0.0.1/tcp/9095/ws/p2p/{peer_id}")) + logger.info(f"Host {name} listening on: {host.get_network().listening_addrs()}") + return host, peer_id + +async def run_loopback_test(): + host_b, peer_id_b = await build_host("webrtc") + webrtc_b = WebRTCTransport(peer_id=peer_id_b, host=host_b) + + listener = await webrtc_b.create_listener(lambda conn: logger.info("[B] Listener ready")) + + async def act_as_server(): + conn = await listener.accept() + msg = await conn.read() + logger.info(f"[B] Got message: {msg.decode()}") + await conn.write(b"Reply from B") + + async def act_as_client(): + host_a, peer_id_a = await build_host("webrtc") + webrtc_a = WebRTCTransport(peer_id=peer_id_a, host=host_a) + + conn = await webrtc_a.dial(Multiaddr("/webrtc")) + logger.info("[A] Dial successful") + + await conn.write(b"Hello from A") + response = await conn.read() + logger.info(f"[A] Got response: {response.decode()}") + + async with trio.open_nursery() as nursery: + nursery.start_soon(act_as_server) + await trio.sleep(1) + nursery.start_soon(act_as_client) + +if __name__ == "__main__": + trio.run(run_loopback_test) + diff --git a/libp2p/transport/webrtc/webrtc.py b/libp2p/transport/webrtc/webrtc.py index 2ef7ec10a..e2f0fb524 100644 --- a/libp2p/transport/webrtc/webrtc.py +++ b/libp2p/transport/webrtc/webrtc.py @@ -1,69 +1,89 @@ -import asyncio +import trio import json -import websockets import logging from aiortc import RTCIceCandidate, RTCPeerConnection, RTCSessionDescription, RTCDataChannel from libp2p.pubsub.gossipsub import GossipSub from libp2p.pubsub.pubsub import Pubsub +from libp2p.abc import ITransport, IRawConnection, IListener +from typing import Callable +from libp2p.peer.id import ID from libp2p.host.basic_host import BasicHost from multiaddr import Multiaddr -SIGNALING_SERVER_URL = "ws://localhost:8765" - -# Initialize logger -logger = logging.getLogger("webrtc-transport") +logger = logging.getLogger("webrtc") logging.basicConfig(level=logging.INFO) +SIGNAL_PROTOCOL = "/libp2p/webrtc/signal/1.0.0" + +class WebRTCRawConnection(IRawConnection): + def __init__(self, channel: RTCDataChannel): + self.channel = channel + self.receive_channel, self.send_channel = trio.open_memory_channel(0) + + @channel.on("message") + def on_message(message): + self.send_channel.send_nowait(message) + + async def read(self, n: int = -1) -> bytes: + return await self.receive_channel.receive() + + async def write(self, data: bytes) -> None: + self.channel.send(data) + + async def close(self) -> None: + self.channel.close() + +class WebRTCListener(IListener): + def __init__(self, host, peer_id: ID): + self.host = host + self.peer_id = peer_id + self.conn_send_channel, self.conn_recv_channel = trio.open_memory_channel(0) + + async def listen(self, maddr: Multiaddr) -> None: + await self.host.set_stream_handler(SIGNAL_PROTOCOL, self._handle_stream) + + async def accept(self) -> IRawConnection: + return await self.conn_recv_channel.receive() + + async def close(self) -> None: + pass + + async def _handle_stream(self, stream) -> None: + pc = RTCPeerConnection() + channel_ready = trio.Event() + + @pc.on("datachannel") + def on_datachannel(channel): + @channel.on("open") + def opened(): + channel_ready.set() + + self.conn_send_channel.send_nowait(WebRTCRawConnection(channel)) + + offer_data = await stream.read() + offer_msg = json.loads(offer_data.decode()) + offer = RTCSessionDescription(**offer_msg) + await pc.setRemoteDescription(offer) + + answer = await pc.createAnswer() + await pc.setLocalDescription(answer) + + await stream.write(json.dumps({ + "sdp": answer.sdp, + "type": answer.type + }).encode()) + + await channel_ready.wait() + -class WebRTCTransport: - def __init__(self, peer_id, host: BasicHost): +class WebRTCTransport(ITransport): + def __init__(self, peer_id, host: BasicHost, config=None): self.peer_id = peer_id self.host = host + self.config = config or {} self.peer_connection = RTCPeerConnection() self.data_channel = None self.pubsub = None - self.websocket = None - - async def connect_signaling_server(self): - """Connects to the WebSocket-based signaling server""" - try: - self.websocket = await websockets.connect(SIGNALING_SERVER_URL) - await self.websocket.send(json.dumps({"type": "register", "peer_id": self.peer_id})) - asyncio.create_task(self.listen_signaling()) - except Exception as e: - logger.error(f"Failed to connect to signaling server: {e}") - - async def listen_signaling(self): - """Listens for incoming SDP offers and answers""" - try: - async for message in self.websocket: - data = json.loads(message) - if data["type"] == "offer": - await self.handle_offer(data) - elif data["type"] == "answer": - await self.handle_answer(data) - except Exception as e: - logger.error(f"Error in signaling listener: {e}") - - async def handle_offer(self, data): - """Handles incoming SDP offers""" - offer = RTCSessionDescription(sdp=data["sdp"], type=data["type"]) - await self.peer_connection.setRemoteDescription(offer) - - # Create an answer - answer = await self.peer_connection.createAnswer() - await self.peer_connection.setLocalDescription(answer) - - await self.websocket.send(json.dumps({ - "type": "answer", - "target": data["peer_id"], - "sdp": answer.sdp - })) - - async def handle_answer(self, data): - """Handles incoming SDP answers""" - answer = RTCSessionDescription(sdp=data["sdp"], type=data["type"]) - await self.peer_connection.setRemoteDescription(answer) - + async def create_data_channel(self): """Creates and opens a WebRTC data channel""" self.data_channel = self.peer_connection.createDataChannel("libp2p-webrtc") @@ -76,19 +96,40 @@ def on_open(): def on_message(message): logger.info(f"Received message from peer {self.peer_id}: {message}") - async def initiate_connection(self, target_peer_id): - """Initiates connection with a peer""" - self.data_channel = self.peer_connection.createDataChannel("libp2p-webrtc") + async def handle_offer_from_peer(self, stream, data): + """Handle offer and send back answer on same stream""" + offer = RTCSessionDescription(sdp=data["sdp"], type=data["sdpType"]) + await self.peer_connection.setRemoteDescription(offer) - offer = await self.peer_connection.createOffer() - await self.peer_connection.setLocalDescription(offer) + answer = await self.peer_connection.createAnswer() + await self.peer_connection.setLocalDescription(answer) - await self.websocket.send(json.dumps({ - "type": "offer", - "peer_id": self.peer_id, - "target": target_peer_id, - "sdp": offer.sdp - })) + response = { + "type": "answer", + "sdp": answer.sdp, + "sdpType": answer.type, + "peer_id": self.peer_id + } + + await stream.write(json.dumps(response).encode()) + + async def handle_answer_from_peer(self, data): + """Handle SDP answer from peer""" + answer = RTCSessionDescription(sdp=data["sdp"], type=data["sdpType"]) + await self.peer_connection.setRemoteDescription(answer) + + async def handle_ice_candidate(self, data): + """Optional: ICE candidate support""" + candidate = RTCIceCandidate( + component=data["component"], + foundation=data["foundation"], + priority=data["priority"], + ip=data["ip"], + protocol=data["protocol"], + port=data["port"], + type=data["candidateType"] + ) + await self.peer_connection.addIceCandidate(candidate) async def start_peer_discovery(self): """Starts peer discovery using GossipSub""" @@ -101,12 +142,48 @@ async def handle_message(): async for msg in topic: logger.info(f"Discovered Peer: {msg.data.decode()}") - asyncio.create_task(handle_message()) + async with trio.open_nursery() as nursery: + nursery.start_soon(handle_message) - # Advertise this peer await self.pubsub.publish("webrtc-peer-discovery", self.peer_id.encode()) -# Multiaddr Parsing & Validation + async def create_listener(self, handler: Callable[[IRawConnection], None]) -> IListener: + listener = WebRTCListener(self.host, self.peer_id) + await listener.listen(Multiaddr(f"/ip4/147.28.186.157/tcp/9095/p2p/12D3KooWFhXabKDwALpzqMbto94sB7rvmZ6M28hs9Y9xSopDKwQr/p2p-circuit/webrtc")) + self.host.set_stream_handler(SIGNAL_PROTOCOL, listener._handle_stream) + return listener + + async def dial(self, maddr: Multiaddr) -> IRawConnection: + peer_id = parse_webrtc_multiaddr(maddr) + stream = await self.host.new_stream(peer_id, [SIGNAL_PROTOCOL]) + + pc = RTCPeerConnection() + channel = pc.createDataChannel("libp2p") + + channel_ready = trio.Event() + + @channel.on("open") + def on_open(): + channel_ready.set() + + offer = await pc.createOffer() + await pc.setLocalDescription(offer) + + await stream.write(json.dumps({ + "type": "offer", + "peer_id": self.peer_id, + "sdp": offer.sdp, + "sdpType": offer.type, + }).encode()) + + answer_data = await stream.read() + answer_msg = json.loads(answer_data.decode()) + answer = RTCSessionDescription(**answer_msg) + await pc.setRemoteDescription(answer) + + await channel_ready.wait() + + return WebRTCRawConnection(channel) def parse_webrtc_multiaddr(multiaddr_str): """Parse and validate a WebRTC multiaddr.""" try: @@ -117,3 +194,5 @@ def parse_webrtc_multiaddr(multiaddr_str): except Exception as e: logger.error(f"Failed to parse multiaddr: {e}") return None + + diff --git a/libp2p/transport/webrtc/webrtc_signal_protocol.py b/libp2p/transport/webrtc/webrtc_signal_protocol.py new file mode 100644 index 000000000..ebfe93578 --- /dev/null +++ b/libp2p/transport/webrtc/webrtc_signal_protocol.py @@ -0,0 +1,33 @@ +# webrtc_signal_protocol.py +# libp2p stream handler that dispatches incoming messages to the WebRTCTransport. + +import json +import logging + +logger = logging.getLogger("signal-protocol") +PROTOCOL_ID = "/libp2p/webrtc/signal/1.0.0" + +class WebRTCSignalingProtocol: + def __init__(self, transport): + self.transport = transport # Reference to WebRTCTransport + + async def handle_stream(self, stream): + """Handle incoming signaling messages on a libp2p stream""" + try: + while True: + data = await stream.read() + if not data: + break + + message = json.loads(data.decode()) + + msg_type = message.get("type") + if msg_type == "offer": + await self.transport.handle_offer_from_peer(stream, message) + elif msg_type == "answer": + await self.transport.handle_answer_from_peer(message) + elif msg_type == "ice": + await self.transport.handle_ice_candidate(message) + except Exception as e: + logger.error(f"Error handling signaling stream: {e}") + From e73b5199e480b40963b72ff9fd3500614d3d9e36 Mon Sep 17 00:00:00 2001 From: Neha Kumari Date: Mon, 14 Apr 2025 01:37:54 +0530 Subject: [PATCH 3/9] lint-fixes --- libp2p/transport/webrtc/test_loopback.py | 89 +++++++---- libp2p/transport/webrtc/webrtc.py | 147 +++++++++++------- .../webrtc/webrtc_signal_protocol.py | 14 +- 3 files changed, 162 insertions(+), 88 deletions(-) diff --git a/libp2p/transport/webrtc/test_loopback.py b/libp2p/transport/webrtc/test_loopback.py index 0ecdd2d9a..72122d7d0 100644 --- a/libp2p/transport/webrtc/test_loopback.py +++ b/libp2p/transport/webrtc/test_loopback.py @@ -1,61 +1,96 @@ import logging + +from multiaddr import ( + Multiaddr, +) import trio -from multiaddr import Multiaddr -from libp2p.peer.id import ID + +from libp2p.crypto.ed25519 import ( + create_new_key_pair, +) +from libp2p.host.basic_host import ( + BasicHost, +) +from libp2p.network.swarm import ( + Swarm, +) +from libp2p.peer.id import ( + ID, +) +from libp2p.peer.peerstore import ( + PeerStore, +) +from libp2p.security.noise.transport import ( + PROTOCOL_ID, +) from libp2p.security.noise.transport import Transport as NoiseTransport -from libp2p.transport.tcp.tcp import TCP -from libp2p.transport.upgrader import TransportUpgrader -from libp2p.stream_muxer.mplex.mplex import Mplex -from libp2p.host.basic_host import BasicHost -from libp2p.network.swarm import Swarm -from libp2p.crypto.ed25519 import create_new_key_pair -from libp2p.security.noise.transport import PROTOCOL_ID -from libp2p.stream_muxer.mplex.mplex import MPLEX_PROTOCOL_ID -from libp2p.peer.peerstore import PeerStore -from .webrtc import WebRTCTransport +from libp2p.stream_muxer.mplex.mplex import ( + MPLEX_PROTOCOL_ID, + Mplex, +) +from libp2p.transport.tcp.tcp import ( + TCP, +) +from libp2p.transport.upgrader import ( + TransportUpgrader, +) + +from .webrtc import ( + WebRTCTransport, +) logging.basicConfig(level=logging.INFO) logger = logging.getLogger("webrtc-loopback-test") -async def build_host(name): + +async def build_host(name: str) -> tuple[BasicHost, ID]: key_pair = create_new_key_pair() peer_id = ID.from_pubkey(key_pair.public_key) logger.info(f"Peer {name} ID: {peer_id}") base_transport = TCP() - peer_store= PeerStore() + peer_store = PeerStore() secure_transports = { - PROTOCOL_ID: NoiseTransport(libp2p_keypair=key_pair), + PROTOCOL_ID: NoiseTransport(libp2p_keypair=key_pair), } - muxer_transports = { - MPLEX_PROTOCOL_ID: Mplex - } + muxer_transports = {MPLEX_PROTOCOL_ID: Mplex} upgrader = TransportUpgrader(secure_transports, muxer_transports) swarm = Swarm(peer_id, peer_store, upgrader, base_transport) host = BasicHost(swarm) - await host.get_network().listen(Multiaddr(f"/ip4/127.0.0.1/tcp/9095/ws/p2p/{peer_id}")) - logger.info(f"Host {name} listening on: {host.get_network().listening_addrs()}") + await host.get_network().listen( + Multiaddr(f"/ip4/127.0.0.1/tcp/9095/ws/p2p/{peer_id}") + ) + logger.info(f"Host {name} listening on: {host.get_network().listen(peer_id)}") return host, peer_id -async def run_loopback_test(): + +async def run_loopback_test() -> None: host_b, peer_id_b = await build_host("webrtc") webrtc_b = WebRTCTransport(peer_id=peer_id_b, host=host_b) + print(f"[*] Server Peer ID: {peer_id_b}") + print(f"[*] Listening Addrs: {host_b.get_network().listen(peer_id_b)}") - listener = await webrtc_b.create_listener(lambda conn: logger.info("[B] Listener ready")) + listener = await webrtc_b.create_listener( + lambda conn: logger.info("[B] Listener ready") + ) - async def act_as_server(): + async def act_as_server() -> None: conn = await listener.accept() msg = await conn.read() logger.info(f"[B] Got message: {msg.decode()}") await conn.write(b"Reply from B") - async def act_as_client(): + async def act_as_client() -> None: host_a, peer_id_a = await build_host("webrtc") webrtc_a = WebRTCTransport(peer_id=peer_id_a, host=host_a) + print(f"[*] Client Peer ID: {peer_id_a}") + print(f"[*] Listening Addrs: {host_a.get_network().listening_addrs()}") - conn = await webrtc_a.dial(Multiaddr("/webrtc")) + conn = await webrtc_a.dial( + Multiaddr(f"/ip4/127.0.0.1/tcp/9095/ws/p2p/{peer_id_b}/p2p-circuit/webrtc") + ) logger.info("[A] Dial successful") await conn.write(b"Hello from A") @@ -64,9 +99,9 @@ async def act_as_client(): async with trio.open_nursery() as nursery: nursery.start_soon(act_as_server) - await trio.sleep(1) + await trio.sleep(1) nursery.start_soon(act_as_client) + if __name__ == "__main__": trio.run(run_loopback_test) - diff --git a/libp2p/transport/webrtc/webrtc.py b/libp2p/transport/webrtc/webrtc.py index e2f0fb524..40ea1ca62 100644 --- a/libp2p/transport/webrtc/webrtc.py +++ b/libp2p/transport/webrtc/webrtc.py @@ -1,26 +1,55 @@ -import trio +from collections.abc import ( + Coroutine, +) import json import logging -from aiortc import RTCIceCandidate, RTCPeerConnection, RTCSessionDescription, RTCDataChannel -from libp2p.pubsub.gossipsub import GossipSub -from libp2p.pubsub.pubsub import Pubsub -from libp2p.abc import ITransport, IRawConnection, IListener -from typing import Callable -from libp2p.peer.id import ID -from libp2p.host.basic_host import BasicHost -from multiaddr import Multiaddr +from typing import ( + Any, + Callable, + Optional, +) + +from aiortc import ( + RTCDataChannel, + RTCIceCandidate, + RTCPeerConnection, + RTCSessionDescription, +) +from multiaddr import ( + Multiaddr, +) +import trio + +from libp2p.abc import ( + IListener, + IRawConnection, + ITransport, +) +from libp2p.host.basic_host import ( + BasicHost, +) +from libp2p.peer.id import ( + ID, +) +from libp2p.pubsub.gossipsub import ( + GossipSub, +) +from libp2p.pubsub.pubsub import ( + Pubsub, +) logger = logging.getLogger("webrtc") logging.basicConfig(level=logging.INFO) SIGNAL_PROTOCOL = "/libp2p/webrtc/signal/1.0.0" + class WebRTCRawConnection(IRawConnection): def __init__(self, channel: RTCDataChannel): self.channel = channel self.receive_channel, self.send_channel = trio.open_memory_channel(0) @channel.on("message") - def on_message(message): + def on_message(message: Any) -> None: self.send_channel.send_nowait(message) async def read(self, n: int = -1) -> bytes: @@ -32,6 +61,7 @@ async def write(self, data: bytes) -> None: async def close(self) -> None: self.channel.close() + class WebRTCListener(IListener): def __init__(self, host, peer_id: ID): self.host = host @@ -45,14 +75,14 @@ async def accept(self) -> IRawConnection: return await self.conn_recv_channel.receive() async def close(self) -> None: - pass + pass async def _handle_stream(self, stream) -> None: pc = RTCPeerConnection() channel_ready = trio.Event() @pc.on("datachannel") - def on_datachannel(channel): + def on_datachannel(channel: RTCDataChannel) -> None: @channel.on("open") def opened(): channel_ready.set() @@ -67,58 +97,57 @@ def opened(): answer = await pc.createAnswer() await pc.setLocalDescription(answer) - await stream.write(json.dumps({ - "sdp": answer.sdp, - "type": answer.type - }).encode()) + await stream.write( + json.dumps({"sdp": answer.sdp, "type": answer.type}).encode() + ) await channel_ready.wait() class WebRTCTransport(ITransport): - def __init__(self, peer_id, host: BasicHost, config=None): + def __init__(self, peer_id, host: BasicHost, config: Optional[dict] = None): self.peer_id = peer_id self.host = host - self.config = config or {} + self.config = config or {} self.peer_connection = RTCPeerConnection() - self.data_channel = None - self.pubsub = None - - async def create_data_channel(self): + self.data_channel: Optional[RTCDataChannel] = None + self.pubsub: Optional[Pubsub] = None + + async def create_data_channel(self) -> None: """Creates and opens a WebRTC data channel""" self.data_channel = self.peer_connection.createDataChannel("libp2p-webrtc") @self.data_channel.on("open") - def on_open(): + def on_open() -> None: logger.info(f"Data channel open with peer {self.peer_id}") @self.data_channel.on("message") - def on_message(message): + def on_message(message: Any) -> None: logger.info(f"Received message from peer {self.peer_id}: {message}") - async def handle_offer_from_peer(self, stream, data): - """Handle offer and send back answer on same stream""" - offer = RTCSessionDescription(sdp=data["sdp"], type=data["sdpType"]) - await self.peer_connection.setRemoteDescription(offer) + async def handle_offer_from_peer(self, stream: Any, data: dict) -> None: + """Handle offer and send back answer on same stream""" + offer = RTCSessionDescription(sdp=data["sdp"], type=data["sdpType"]) + await self.peer_connection.setRemoteDescription(offer) - answer = await self.peer_connection.createAnswer() - await self.peer_connection.setLocalDescription(answer) + answer = await self.peer_connection.createAnswer() + await self.peer_connection.setLocalDescription(answer) - response = { - "type": "answer", - "sdp": answer.sdp, - "sdpType": answer.type, - "peer_id": self.peer_id - } + response = { + "type": "answer", + "sdp": answer.sdp, + "sdpType": answer.type, + "peer_id": self.peer_id, + } - await stream.write(json.dumps(response).encode()) + await stream.write(json.dumps(response).encode()) - async def handle_answer_from_peer(self, data): + async def handle_answer_from_peer(self, data: dict) -> None: """Handle SDP answer from peer""" answer = RTCSessionDescription(sdp=data["sdp"], type=data["sdpType"]) await self.peer_connection.setRemoteDescription(answer) - async def handle_ice_candidate(self, data): + async def handle_ice_candidate(self, data: dict) -> None: """Optional: ICE candidate support""" candidate = RTCIceCandidate( component=data["component"], @@ -127,18 +156,18 @@ async def handle_ice_candidate(self, data): ip=data["ip"], protocol=data["protocol"], port=data["port"], - type=data["candidateType"] + type=data["candidateType"], ) await self.peer_connection.addIceCandidate(candidate) - async def start_peer_discovery(self): + async def start_peer_discovery(self) -> None: """Starts peer discovery using GossipSub""" gossipsub = GossipSub() self.pubsub = Pubsub(self.host, gossipsub, None) topic = await self.pubsub.subscribe("webrtc-peer-discovery") - async def handle_message(): + async def handle_message() -> None: async for msg in topic: logger.info(f"Discovered Peer: {msg.data.decode()}") @@ -147,9 +176,15 @@ async def handle_message(): await self.pubsub.publish("webrtc-peer-discovery", self.peer_id.encode()) - async def create_listener(self, handler: Callable[[IRawConnection], None]) -> IListener: + async def create_listener( + self, handler: Callable[[IRawConnection], Coroutine[Any, Any, None]] + ) -> IListener: listener = WebRTCListener(self.host, self.peer_id) - await listener.listen(Multiaddr(f"/ip4/147.28.186.157/tcp/9095/p2p/12D3KooWFhXabKDwALpzqMbto94sB7rvmZ6M28hs9Y9xSopDKwQr/p2p-circuit/webrtc")) + await listener.listen( + Multiaddr( + f"/ip4/147.28.186.157/tcp/9095/p2p/12D3KooWFhXabKDwALpzqMbto94sB7rvmZ6M28hs9Y9xSopDKwQr/p2p-circuit/webrtc" + ) + ) self.host.set_stream_handler(SIGNAL_PROTOCOL, listener._handle_stream) return listener @@ -163,18 +198,22 @@ async def dial(self, maddr: Multiaddr) -> IRawConnection: channel_ready = trio.Event() @channel.on("open") - def on_open(): + def on_open() -> None: channel_ready.set() offer = await pc.createOffer() await pc.setLocalDescription(offer) - await stream.write(json.dumps({ - "type": "offer", - "peer_id": self.peer_id, - "sdp": offer.sdp, - "sdpType": offer.type, - }).encode()) + await stream.write( + json.dumps( + { + "type": "offer", + "peer_id": self.peer_id, + "sdp": offer.sdp, + "sdpType": offer.type, + } + ).encode() + ) answer_data = await stream.read() answer_msg = json.loads(answer_data.decode()) @@ -184,7 +223,9 @@ def on_open(): await channel_ready.wait() return WebRTCRawConnection(channel) -def parse_webrtc_multiaddr(multiaddr_str): + + +def parse_webrtc_multiaddr(multiaddr_str: str) -> Multiaddr: """Parse and validate a WebRTC multiaddr.""" try: addr = Multiaddr(multiaddr_str) @@ -194,5 +235,3 @@ def parse_webrtc_multiaddr(multiaddr_str): except Exception as e: logger.error(f"Failed to parse multiaddr: {e}") return None - - diff --git a/libp2p/transport/webrtc/webrtc_signal_protocol.py b/libp2p/transport/webrtc/webrtc_signal_protocol.py index ebfe93578..2ceb731d7 100644 --- a/libp2p/transport/webrtc/webrtc_signal_protocol.py +++ b/libp2p/transport/webrtc/webrtc_signal_protocol.py @@ -1,17 +1,18 @@ -# webrtc_signal_protocol.py -# libp2p stream handler that dispatches incoming messages to the WebRTCTransport. - import json import logging +from typing import ( + Any, +) logger = logging.getLogger("signal-protocol") PROTOCOL_ID = "/libp2p/webrtc/signal/1.0.0" + class WebRTCSignalingProtocol: - def __init__(self, transport): - self.transport = transport # Reference to WebRTCTransport + def __init__(self, transport: Any): + self.transport = transport - async def handle_stream(self, stream): + async def handle_stream(self, stream: Any) -> None: """Handle incoming signaling messages on a libp2p stream""" try: while True: @@ -30,4 +31,3 @@ async def handle_stream(self, stream): await self.transport.handle_ice_candidate(message) except Exception as e: logger.error(f"Error handling signaling stream: {e}") - From 95e0ec970d8d69e8cbb2155a9145202acc74f39d Mon Sep 17 00:00:00 2001 From: Neha Kumari Date: Sat, 19 Apr 2025 01:44:23 +0530 Subject: [PATCH 4/9] refactor-webrtc-setup-and-loopback-test --- libp2p/transport/webrtc/connection.py | 82 ++++++++ libp2p/transport/webrtc/listener.py | 103 ++++++++++ libp2p/transport/webrtc/test_loopback.py | 146 ++++++++++---- libp2p/transport/webrtc/utils.py | 52 +++++ libp2p/transport/webrtc/webrtc.py | 182 +++++++----------- .../webrtc/webrtc_signal_protocol.py | 33 ---- 6 files changed, 417 insertions(+), 181 deletions(-) create mode 100644 libp2p/transport/webrtc/connection.py create mode 100644 libp2p/transport/webrtc/listener.py create mode 100644 libp2p/transport/webrtc/utils.py delete mode 100644 libp2p/transport/webrtc/webrtc_signal_protocol.py diff --git a/libp2p/transport/webrtc/connection.py b/libp2p/transport/webrtc/connection.py new file mode 100644 index 000000000..f44bc7b35 --- /dev/null +++ b/libp2p/transport/webrtc/connection.py @@ -0,0 +1,82 @@ +import logging +from typing import ( + Any, + Tuple, +) +import trio + +from libp2p.peer.id import ( + ID, +) +from aiortc import ( + RTCDataChannel +) +from trio import ( + MemoryReceiveChannel, + MemorySendChannel +) +import trio + +from libp2p.abc import ( + ISecureConn, +) +from libp2p.stream_muxer.mplex.mplex import ( + Mplex +) + +logger = logging.getLogger("webrtc") +logging.basicConfig(level=logging.INFO) + + +class WebRTCRawConnection(ISecureConn): + def __init__(self, peer_id: ID, channel: RTCDataChannel): + self.peer_id = peer_id + self.channel = channel + self.send_channel: MemorySendChannel[Any] + self.receive_channel: MemoryReceiveChannel[Any] + self.send_channel, self.receive_channel = trio.open_memory_channel(0) + + @channel.on("message") + def on_message(message: Any) -> None: + self.send_channel.send_nowait(message) + + self.mplex = Mplex(self, self.peer_id) + + def _send_func(self, data: bytes) -> None: + self.channel.send(data) + + async def _recv_func(self) -> bytes: + return await self.receive_channel.receive() + + async def open_stream(self) -> Any: + return await self.mplex.open_stream() + + async def accept_stream(self) -> Any: + return await self.mplex.accept_stream() + + async def read(self, n: int = -1) -> bytes: + return await self.receive_channel.receive() + + async def write(self, data: bytes) -> None: + self.channel.send(data) + + def get_remote_address(self) -> Tuple[str, int] | None: + return self.get_remote_address() + + async def close(self) -> None: + self.channel.close() + await self.send_channel.aclose() + await self.receive_channel.aclose() + await self.mplex.close() + + def get_local_peer(self) -> ID: + return self.get_local_peer() + + def get_local_private_key(self) -> Any: + return self.get_local_private_key() + + def get_remote_peer(self) -> ID: + return self.get_remote_peer() + + def get_remote_public_key(self) -> Any: + return self.get_remote_public_key() diff --git a/libp2p/transport/webrtc/listener.py b/libp2p/transport/webrtc/listener.py new file mode 100644 index 000000000..bbe749dc7 --- /dev/null +++ b/libp2p/transport/webrtc/listener.py @@ -0,0 +1,103 @@ +import json +import logging +from typing import ( + Any, + Tuple, +) +from aiortc import ( + RTCDataChannel, + RTCPeerConnection, + RTCSessionDescription, +) +from multiaddr import ( + Multiaddr, +) +from trio import ( + Nursery, + Event, + MemoryReceiveChannel, + MemorySendChannel +) +import trio + +from libp2p.abc import ( + IListener, + TProtocol +) +from libp2p.host.basic_host import ( + BasicHost, +) +from libp2p.peer.id import ( + ID, +) +from .connection import ( + WebRTCRawConnection +) + +logger = logging.getLogger("webrtc") +logging.basicConfig(level=logging.INFO) +SIGNAL_PROTOCOL: TProtocol = TProtocol("/libp2p/webrtc/signal/1.0.0") + + +class WebRTCListener(IListener): + def __init__(self, host: BasicHost, peer_id: ID): + self.host = host + self.peer_id = peer_id + self.conn_send_channel: MemorySendChannel[WebRTCRawConnection] + self.conn_receive_channel: MemoryReceiveChannel[WebRTCRawConnection] + self.conn_send_channel, self.conn_receive_channel = trio.open_memory_channel(0) + + async def listen(self, maddr: Multiaddr, nursery: Nursery) -> bool: + self.host.set_stream_handler(SIGNAL_PROTOCOL, lambda stream: nursery.start_soon(self._handle_stream, stream)) + await self.host.get_network().listen(maddr) + return True + + async def accept(self) -> WebRTCRawConnection: + return await self.conn_receive_channel.receive() + + async def close(self) -> None: + await self.conn_send_channel.aclose() + await self.conn_receive_channel.aclose() + + async def _handle_stream(self, stream: Any) -> None: + pc = RTCPeerConnection() + channel_ready = Event() + + @pc.on("datachannel") + def on_datachannel(channel: RTCDataChannel) -> None: + @channel.on("open") + def opened() -> None: + channel_ready.set() + self.conn_send_channel.send_nowait(WebRTCRawConnection(self.peer_id, channel)) + + @pc.on("icecandidate") + def on_ice_candidate(candidate: Any) -> None: + if candidate: + msg = { + "type": "ice", + "candidateType": candidate.type, + "component": candidate.component, + "foundation": candidate.foundation, + "priority": candidate.priority, + "ip": candidate.address, + "port": candidate.port, + "protocol": candidate.protocol, + } + trio.lowlevel.spawn_system_task(stream.write, json.dumps(msg).encode()) + + offer_data = await stream.read() + offer_msg = json.loads(offer_data.decode()) + offer = RTCSessionDescription(**offer_msg) + await pc.setRemoteDescription(offer) + + answer = await pc.createAnswer() + await pc.setLocalDescription(answer) + + await stream.write(json.dumps({"sdp": answer.sdp, "type": answer.type}).encode()) + await channel_ready.wait() + await stream.close() + + def get_addrs(self) -> Tuple[Multiaddr, ...]: + return ( + Multiaddr(f"/ip4/127.0.0.1/tcp/4001/ws/p2p/{self.peer_id}/p2p-circuit/webrtc"), + ) diff --git a/libp2p/transport/webrtc/test_loopback.py b/libp2p/transport/webrtc/test_loopback.py index 72122d7d0..e704e29c4 100644 --- a/libp2p/transport/webrtc/test_loopback.py +++ b/libp2p/transport/webrtc/test_loopback.py @@ -1,9 +1,17 @@ +import json import logging from multiaddr import ( Multiaddr, ) +from multiaddr.protocols import ( + Protocol, + add_protocol, +) import trio +from trio import ( + Nursery, +) from libp2p.crypto.ed25519 import ( create_new_key_pair, @@ -28,14 +36,18 @@ MPLEX_PROTOCOL_ID, Mplex, ) -from libp2p.transport.tcp.tcp import ( - TCP, -) from libp2p.transport.upgrader import ( TransportUpgrader, ) +from .connection import ( + WebRTCRawConnection, +) +from .listener import ( + WebRTCListener, +) from .webrtc import ( + SIGNAL_PROTOCOL, WebRTCTransport, ) @@ -43,65 +55,119 @@ logger = logging.getLogger("webrtc-loopback-test") -async def build_host(name: str) -> tuple[BasicHost, ID]: +async def build_host(name: str) -> tuple[BasicHost, ID, WebRTCTransport]: key_pair = create_new_key_pair() peer_id = ID.from_pubkey(key_pair.public_key) - logger.info(f"Peer {name} ID: {peer_id}") + logger.info(f"[{name}] Peer ID: {peer_id}") - base_transport = TCP() + webrtc_transport = WebRTCTransport(peer_id=peer_id, host=None) peer_store = PeerStore() - secure_transports = { - PROTOCOL_ID: NoiseTransport(libp2p_keypair=key_pair), - } - muxer_transports = {MPLEX_PROTOCOL_ID: Mplex} + secure_transports = {PROTOCOL_ID: NoiseTransport(libp2p_keypair=key_pair)} + muxer_transports = {MPLEX_PROTOCOL_ID: Mplex} upgrader = TransportUpgrader(secure_transports, muxer_transports) - swarm = Swarm(peer_id, peer_store, upgrader, base_transport) + + swarm = Swarm(peer_id, peer_store, upgrader, webrtc_transport) host = BasicHost(swarm) + webrtc_transport.host = host + + return host, peer_id, webrtc_transport + - await host.get_network().listen( - Multiaddr(f"/ip4/127.0.0.1/tcp/9095/ws/p2p/{peer_id}") - ) - logger.info(f"Host {name} listening on: {host.get_network().listen(peer_id)}") - return host, peer_id +async def run_loopback_test(nursery: Nursery) -> None: + host_b, peer_id_b, webrtc_transport_b = await build_host("Server") + logger.info(f"[B] Peer ID: {peer_id_b}") + logger.info(f"[B] Listening Addrs: {host_b.get_connected_peers()}") + webrtc_proto = Protocol(name="webrtc", code=277, codec=None) + add_protocol(webrtc_proto) + webrtc_conn = webrtc_transport_b.create_listener(WebRTCListener) + # await webrtc_listener.listen( + # Multiaddr(f"/ip4/127.0.0.1/tcp/9095/ws/p2p/{peer_id_b}/p2p-circuit/webrtc"), + # nursery, + # ) -async def run_loopback_test() -> None: - host_b, peer_id_b = await build_host("webrtc") - webrtc_b = WebRTCTransport(peer_id=peer_id_b, host=host_b) - print(f"[*] Server Peer ID: {peer_id_b}") - print(f"[*] Listening Addrs: {host_b.get_network().listen(peer_id_b)}") + # for addr in webrtc_listener.: + # logger.info(f"[B] Listening on: {addr}") - listener = await webrtc_b.create_listener( - lambda conn: logger.info("[B] Listener ready") - ) + logger.info("[B] Listening WebRTC setup complete.") async def act_as_server() -> None: - conn = await listener.accept() - msg = await conn.read() - logger.info(f"[B] Got message: {msg.decode()}") - await conn.write(b"Reply from B") + try: + logger.info("[B] Waiting to accept connection...") + conn: WebRTCRawConnection = await webrtc_conn + logger.info("[B] Connection accepted.") + + stream = await host_b.new_stream(conn.peer_id, [SIGNAL_PROTOCOL]) + offer_data = await stream.read() + offer_json = json.loads(offer_data.decode()) + + await webrtc_transport_b.handle_offer_from_peer(stream, offer_json) + answer = await webrtc_transport_b.peer_connection.createAnswer() + await webrtc_transport_b.peer_connection.setLocalDescription(answer) + + await stream.write( + json.dumps({"sdp": answer.sdp, "sdpType": answer.type}).encode() + ) + + await trio.sleep(1) + + if webrtc_transport_b.data_channel is not None: + raw_conn = WebRTCRawConnection( + peer_id_b, webrtc_transport_b.data_channel + ) + msg = await raw_conn.read() + logger.info(f"[B] Received: {msg.decode()}") + await raw_conn.write(b"Reply from B") + await raw_conn.close() + else: + logger.error("[B] Data channel not established!") + except Exception as e: + logger.error(f"[B] Error in act_as_server: {e}") async def act_as_client() -> None: - host_a, peer_id_a = await build_host("webrtc") - webrtc_a = WebRTCTransport(peer_id=peer_id_a, host=host_a) - print(f"[*] Client Peer ID: {peer_id_a}") - print(f"[*] Listening Addrs: {host_a.get_network().listening_addrs()}") + host_a, peer_id_a, webrtc_client = await build_host("Client") + await webrtc_client.create_data_channel() - conn = await webrtc_a.dial( - Multiaddr(f"/ip4/127.0.0.1/tcp/9095/ws/p2p/{peer_id_b}/p2p-circuit/webrtc") + maddr = Multiaddr( + f"/ip4/127.0.0.1/tcp/4001/ws/p2p/{peer_id_b}/p2p-circuit/webrtc" ) - logger.info("[A] Dial successful") + host_a.get_network().peerstore.add_addr(peer_id_b, maddr, 3000) + logger.info(f"[A] Peerstore updated with address: {maddr}") - await conn.write(b"Hello from A") - response = await conn.read() - logger.info(f"[A] Got response: {response.decode()}") + stream = await host_a.new_stream(peer_id_b, [SIGNAL_PROTOCOL]) + + offer = await webrtc_client.peer_connection.createOffer() + await webrtc_client.peer_connection.setLocalDescription(offer) + await stream.write( + json.dumps({"sdp": offer.sdp, "sdpType": offer.type}).encode() + ) + + answer_data = await stream.read() + answer_json = json.loads(answer_data.decode()) + await webrtc_client.handle_answer_from_peer(answer_json) + + await trio.sleep(1) + + if webrtc_client.data_channel is not None: + conn = WebRTCRawConnection(peer_id_b, webrtc_client.data_channel) + await conn.write(b"Hello from A") + reply = await conn.read() + logger.info(f"[A] Received: {reply.decode()}") + await conn.close() + else: + logger.error("[A] Data channel not established!") async with trio.open_nursery() as nursery: nursery.start_soon(act_as_server) - await trio.sleep(1) + await trio.sleep(1.5) nursery.start_soon(act_as_client) +async def run_loopback_main() -> None: + async with trio.open_nursery() as nursery: + await run_loopback_test(nursery) + + if __name__ == "__main__": - trio.run(run_loopback_test) + trio.run(lambda: run_loopback_main()) diff --git a/libp2p/transport/webrtc/utils.py b/libp2p/transport/webrtc/utils.py new file mode 100644 index 000000000..3599964ce --- /dev/null +++ b/libp2p/transport/webrtc/utils.py @@ -0,0 +1,52 @@ +from multiaddr import ( + Multiaddr, +) +import logging + +logger = logging.getLogger("webrtc") +logging.basicConfig(level=logging.INFO) + +def parse_webrtc_multiaddr(multiaddr_str: str) -> Multiaddr: + """ + Parse, validate, and extract components from a WebRTC multiaddr. + + Expected format: + /ip4|dns4|dns6/
/tcp//p2p//p2p-circuit/webrtc + """ + try: + addr = Multiaddr(multiaddr_str) + protocols = [p.name for p in addr.protocols()] + + if "webrtc" not in protocols: + raise ValueError("Missing /webrtc protocol in multiaddr") + + if "p2p" not in protocols: + raise ValueError("Missing /p2p protocol (peer ID required)") + + # Extracting peer ID and address components + components = addr.items() + peer_id = None + ip_or_dns = None + port = None + + for proto, value in components: + if proto in ("ip4", "ip6", "dns4", "dns6"): + ip_or_dns = value + elif proto == "tcp": + port = value + elif proto == "p2p": + peer_id = value + + if not all([ip_or_dns, port, peer_id]): + raise ValueError("Incomplete multiaddr: Must include IP/DNS, TCP port, and Peer ID") + + return { + "multiaddr": addr, + "peer_id": peer_id, + "ip_or_dns": ip_or_dns, + "port": port + } + + except Exception as e: + logger.error(f"[parse_webrtc_multiaddr] Failed to parse multiaddr: {e}") + return None diff --git a/libp2p/transport/webrtc/webrtc.py b/libp2p/transport/webrtc/webrtc.py index 40ea1ca62..83154371c 100644 --- a/libp2p/transport/webrtc/webrtc.py +++ b/libp2p/transport/webrtc/webrtc.py @@ -1,6 +1,3 @@ -from collections.abc import ( - Coroutine, -) import json import logging from typing import ( @@ -9,6 +6,9 @@ Optional, ) +from _collections_abc import ( + Coroutine, +) from aiortc import ( RTCDataChannel, RTCIceCandidate, @@ -22,8 +22,9 @@ from libp2p.abc import ( IListener, - IRawConnection, + ISecureConn, ITransport, + TProtocol, ) from libp2p.host.basic_host import ( BasicHost, @@ -38,74 +39,25 @@ Pubsub, ) +from .connection import ( + WebRTCRawConnection, +) +from .listener import ( + WebRTCListener, +) +from .utils import ( + parse_webrtc_multiaddr, +) + logger = logging.getLogger("webrtc") logging.basicConfig(level=logging.INFO) -SIGNAL_PROTOCOL = "/libp2p/webrtc/signal/1.0.0" - - -class WebRTCRawConnection(IRawConnection): - def __init__(self, channel: RTCDataChannel): - self.channel = channel - self.receive_channel, self.send_channel = trio.open_memory_channel(0) - - @channel.on("message") - def on_message(message: Any) -> None: - self.send_channel.send_nowait(message) - - async def read(self, n: int = -1) -> bytes: - return await self.receive_channel.receive() - - async def write(self, data: bytes) -> None: - self.channel.send(data) - - async def close(self) -> None: - self.channel.close() - - -class WebRTCListener(IListener): - def __init__(self, host, peer_id: ID): - self.host = host - self.peer_id = peer_id - self.conn_send_channel, self.conn_recv_channel = trio.open_memory_channel(0) - - async def listen(self, maddr: Multiaddr) -> None: - await self.host.set_stream_handler(SIGNAL_PROTOCOL, self._handle_stream) - - async def accept(self) -> IRawConnection: - return await self.conn_recv_channel.receive() - - async def close(self) -> None: - pass - - async def _handle_stream(self, stream) -> None: - pc = RTCPeerConnection() - channel_ready = trio.Event() - - @pc.on("datachannel") - def on_datachannel(channel: RTCDataChannel) -> None: - @channel.on("open") - def opened(): - channel_ready.set() - - self.conn_send_channel.send_nowait(WebRTCRawConnection(channel)) - - offer_data = await stream.read() - offer_msg = json.loads(offer_data.decode()) - offer = RTCSessionDescription(**offer_msg) - await pc.setRemoteDescription(offer) - - answer = await pc.createAnswer() - await pc.setLocalDescription(answer) - - await stream.write( - json.dumps({"sdp": answer.sdp, "type": answer.type}).encode() - ) - - await channel_ready.wait() +SIGNAL_PROTOCOL: TProtocol = TProtocol("/libp2p/webrtc/signal/1.0.0") class WebRTCTransport(ITransport): - def __init__(self, peer_id, host: BasicHost, config: Optional[dict] = None): + def __init__( + self, peer_id: ID, host: BasicHost, config: Optional[dict[str, Any]] = None + ): self.peer_id = peer_id self.host = host self.config = config or {} @@ -114,7 +66,6 @@ def __init__(self, peer_id, host: BasicHost, config: Optional[dict] = None): self.pubsub: Optional[Pubsub] = None async def create_data_channel(self) -> None: - """Creates and opens a WebRTC data channel""" self.data_channel = self.peer_connection.createDataChannel("libp2p-webrtc") @self.data_channel.on("open") @@ -125,30 +76,27 @@ def on_open() -> None: def on_message(message: Any) -> None: logger.info(f"Received message from peer {self.peer_id}: {message}") - async def handle_offer_from_peer(self, stream: Any, data: dict) -> None: - """Handle offer and send back answer on same stream""" + async def handle_offer_from_peer(self, stream: Any, data: dict[str, Any]) -> None: offer = RTCSessionDescription(sdp=data["sdp"], type=data["sdpType"]) await self.peer_connection.setRemoteDescription(offer) answer = await self.peer_connection.createAnswer() await self.peer_connection.setLocalDescription(answer) - response = { + response: dict[str, Any] = { "type": "answer", "sdp": answer.sdp, "sdpType": answer.type, - "peer_id": self.peer_id, + "peer_id": str(self.peer_id), } await stream.write(json.dumps(response).encode()) - async def handle_answer_from_peer(self, data: dict) -> None: - """Handle SDP answer from peer""" + async def handle_answer_from_peer(self, data: dict[str, Any]) -> None: answer = RTCSessionDescription(sdp=data["sdp"], type=data["sdpType"]) await self.peer_connection.setRemoteDescription(answer) - async def handle_ice_candidate(self, data: dict) -> None: - """Optional: ICE candidate support""" + async def handle_ice_candidate(self, data: dict[str, Any]) -> None: candidate = RTCIceCandidate( component=data["component"], foundation=data["foundation"], @@ -160,9 +108,30 @@ async def handle_ice_candidate(self, data: dict) -> None: ) await self.peer_connection.addIceCandidate(candidate) + async def handle_incoming_candidates( + self, stream: Any, peer_connection: RTCPeerConnection + ) -> None: + while True: + try: + raw = await stream.read() + data: dict[str, Any] = json.loads(raw.decode()) + if data.get("type") == "ice": + candidate = RTCIceCandidate( + component=data["component"], + foundation=data["foundation"], + priority=data["priority"], + ip=data["ip"], + protocol=data["protocol"], + port=data["port"], + type=data["candidateType"], + ) + await peer_connection.addIceCandidate(candidate) + except Exception as e: + logger.error(f"[ICE Trickling] Error reading ICE candidate: {e}") + break + async def start_peer_discovery(self) -> None: - """Starts peer discovery using GossipSub""" - gossipsub = GossipSub() + gossipsub = GossipSub(protocols=[], degree=10, degree_low=3, degree_high=15) self.pubsub = Pubsub(self.host, gossipsub, None) topic = await self.pubsub.subscribe("webrtc-peer-discovery") @@ -174,27 +143,31 @@ async def handle_message() -> None: async with trio.open_nursery() as nursery: nursery.start_soon(handle_message) - await self.pubsub.publish("webrtc-peer-discovery", self.peer_id.encode()) + await self.pubsub.publish("webrtc-peer-discovery", str(self.peer_id).encode()) - async def create_listener( - self, handler: Callable[[IRawConnection], Coroutine[Any, Any, None]] - ) -> IListener: - listener = WebRTCListener(self.host, self.peer_id) - await listener.listen( - Multiaddr( - f"/ip4/147.28.186.157/tcp/9095/p2p/12D3KooWFhXabKDwALpzqMbto94sB7rvmZ6M28hs9Y9xSopDKwQr/p2p-circuit/webrtc" - ) - ) - self.host.set_stream_handler(SIGNAL_PROTOCOL, listener._handle_stream) - return listener - - async def dial(self, maddr: Multiaddr) -> IRawConnection: + async def dial(self, maddr: Multiaddr) -> ISecureConn: peer_id = parse_webrtc_multiaddr(maddr) stream = await self.host.new_stream(peer_id, [SIGNAL_PROTOCOL]) pc = RTCPeerConnection() - channel = pc.createDataChannel("libp2p") + @pc.on("icecandidate") + def on_ice_candidate(candidate: Optional[RTCIceCandidate]) -> None: + if candidate: + msg = { + "type": "ice", + "candidateType": candidate.type, + "component": candidate.component, + "foundation": candidate.foundation, + "priority": candidate.priority, + "port": candidate.port, + "protocol": candidate.protocol, + } + trio.lowlevel.spawn_system_task(stream.write, json.dumps(msg).encode()) + + trio.lowlevel.spawn_system_task(self.handle_incoming_candidates, stream, pc) + + channel = pc.createDataChannel("libp2p") channel_ready = trio.Event() @channel.on("open") @@ -208,7 +181,7 @@ def on_open() -> None: json.dumps( { "type": "offer", - "peer_id": self.peer_id, + "peer_id": str(self.peer_id), "sdp": offer.sdp, "sdpType": offer.type, } @@ -216,22 +189,15 @@ def on_open() -> None: ) answer_data = await stream.read() - answer_msg = json.loads(answer_data.decode()) + answer_msg: dict[str, Any] = json.loads(answer_data.decode()) answer = RTCSessionDescription(**answer_msg) await pc.setRemoteDescription(answer) await channel_ready.wait() + return WebRTCRawConnection(self.peer_id, channel) - return WebRTCRawConnection(channel) - - -def parse_webrtc_multiaddr(multiaddr_str: str) -> Multiaddr: - """Parse and validate a WebRTC multiaddr.""" - try: - addr = Multiaddr(multiaddr_str) - if "/webrtc" not in [p.name for p in addr.protocols()]: - raise ValueError("Invalid WebRTC multiaddr: Missing /webrtc protocol") - return addr - except Exception as e: - logger.error(f"Failed to parse multiaddr: {e}") - return None + async def create_listener( + self, handler: Callable[[WebRTCListener], Coroutine[Any, Any, None]] + ) -> IListener: + listener = await self.create_listener(handler=handler) + return listener diff --git a/libp2p/transport/webrtc/webrtc_signal_protocol.py b/libp2p/transport/webrtc/webrtc_signal_protocol.py deleted file mode 100644 index 2ceb731d7..000000000 --- a/libp2p/transport/webrtc/webrtc_signal_protocol.py +++ /dev/null @@ -1,33 +0,0 @@ -import json -import logging -from typing import ( - Any, -) - -logger = logging.getLogger("signal-protocol") -PROTOCOL_ID = "/libp2p/webrtc/signal/1.0.0" - - -class WebRTCSignalingProtocol: - def __init__(self, transport: Any): - self.transport = transport - - async def handle_stream(self, stream: Any) -> None: - """Handle incoming signaling messages on a libp2p stream""" - try: - while True: - data = await stream.read() - if not data: - break - - message = json.loads(data.decode()) - - msg_type = message.get("type") - if msg_type == "offer": - await self.transport.handle_offer_from_peer(stream, message) - elif msg_type == "answer": - await self.transport.handle_answer_from_peer(message) - elif msg_type == "ice": - await self.transport.handle_ice_candidate(message) - except Exception as e: - logger.error(f"Error handling signaling stream: {e}") From 690c1f92d31166e4bed26d0127ac291911ce4379 Mon Sep 17 00:00:00 2001 From: Neha Kumari Date: Tue, 13 May 2025 00:38:26 +0530 Subject: [PATCH 5/9] feat(webrtc): revamp transport and signalling service --- libp2p/transport/webrtc/connection.py | 22 +- libp2p/transport/webrtc/gen_certhash.py | 137 +++++++ libp2p/transport/webrtc/listener.py | 101 +++-- libp2p/transport/webrtc/signal_service.py | 83 ++++ libp2p/transport/webrtc/test_loopback.py | 96 ++--- libp2p/transport/webrtc/test_webrtc.py | 119 ++++++ libp2p/transport/webrtc/utils.py | 52 --- libp2p/transport/webrtc/webrtc.py | 467 +++++++++++++++++++--- 8 files changed, 873 insertions(+), 204 deletions(-) create mode 100644 libp2p/transport/webrtc/gen_certhash.py create mode 100644 libp2p/transport/webrtc/signal_service.py create mode 100644 libp2p/transport/webrtc/test_webrtc.py delete mode 100644 libp2p/transport/webrtc/utils.py diff --git a/libp2p/transport/webrtc/connection.py b/libp2p/transport/webrtc/connection.py index f44bc7b35..ec5e35e6b 100644 --- a/libp2p/transport/webrtc/connection.py +++ b/libp2p/transport/webrtc/connection.py @@ -1,27 +1,25 @@ import logging from typing import ( Any, - Tuple, ) -import trio -from libp2p.peer.id import ( - ID, -) from aiortc import ( - RTCDataChannel -) + RTCDataChannel, +) +import trio from trio import ( MemoryReceiveChannel, - MemorySendChannel + MemorySendChannel, ) -import trio from libp2p.abc import ( ISecureConn, ) +from libp2p.peer.id import ( + ID, +) from libp2p.stream_muxer.mplex.mplex import ( - Mplex + Mplex, ) logger = logging.getLogger("webrtc") @@ -34,7 +32,7 @@ def __init__(self, peer_id: ID, channel: RTCDataChannel): self.channel = channel self.send_channel: MemorySendChannel[Any] self.receive_channel: MemoryReceiveChannel[Any] - self.send_channel, self.receive_channel = trio.open_memory_channel(0) + self.send_channel, self.receive_channel = trio.open_memory_channel(50) @channel.on("message") def on_message(message: Any) -> None: @@ -60,7 +58,7 @@ async def read(self, n: int = -1) -> bytes: async def write(self, data: bytes) -> None: self.channel.send(data) - def get_remote_address(self) -> Tuple[str, int] | None: + def get_remote_address(self) -> tuple[str, int] | None: return self.get_remote_address() async def close(self) -> None: diff --git a/libp2p/transport/webrtc/gen_certhash.py b/libp2p/transport/webrtc/gen_certhash.py new file mode 100644 index 000000000..eeb5a99a1 --- /dev/null +++ b/libp2p/transport/webrtc/gen_certhash.py @@ -0,0 +1,137 @@ +import base58 +import hashlib +import ssl +from typing import ( + Optional, + List, + Tuple, +) +from cryptography import x509 +from cryptography.hazmat.backends import default_backend +import hashlib +import base64 +from cryptography import x509 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.x509.oid import NameOID +import datetime +from multiaddr import Multiaddr +from typing import Tuple +from aiortc import RTCCertificate +from multiaddr.protocols import ( + Protocol, + add_protocol, +) + +SIGNAL_PROTOCOL = "/libp2p/webrtc/signal/1.0.0" + +class CertificateManager(RTCCertificate): + def __init__(self): + self.private_key = None + self.certificate = None + self.certhash = None + + def generate_self_signed_cert(self, common_name: str = "py-libp2p") -> None: + self.private_key = rsa.generate_private_key( + public_exponent=65537, key_size=2048 + ) + subject = issuer = x509.Name([ + x509.NameAttribute(NameOID.COMMON_NAME, common_name) + ]) + self.certificate = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(self.private_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.datetime.utcnow()) + .not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=365)) + .sign(self.private_key, hashes.SHA256()) + ) + self.certhash = self._compute_certhash(self.certificate) + + def _compute_certhash(self, cert: x509.Certificate) -> str: + # Encode in DER format and compute SHA-256 hash + der_bytes = cert.public_bytes(serialization.Encoding.DER) + sha256_hash = hashlib.sha256(der_bytes).digest() + return base64.urlsafe_b64encode(sha256_hash).decode("utf-8").rstrip("=") + + def get_certhash(self) -> str: + return self.certhash + + def get_certificate_pem(self) -> bytes: + return self.certificate.public_bytes(serialization.Encoding.PEM) + + def get_private_key_pem(self) -> bytes: + return self.private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + + +def parse_webrtc_maddr(maddr: Multiaddr) -> Tuple[str, int, str]: + """ + Parse a WebRTC multiaddr like: + /ip4/127.0.0.1/udp/5000/webrtc/certhash//p2p/ + Returns (ip, port, certhash) + """ + addr = Multiaddr(maddr) + ip = None + port = None + certhash = None + + for c in addr.protocols(): + if c.name == "ip4" or c.name == "ip6": + ip = addr.value_for_protocol(c.name) + elif c.name == "udp": + port = int(addr.value_for_protocol("udp")) + elif c.name == "certhash": + certhash = addr.value_for_protocol("certhash") + + if not ip or not port or not certhash: + raise ValueError("Invalid WebRTC multiaddress") + + return ip, port, certhash + + +def generate_local_certhash(cert_pem: str) -> str: + cert = x509.load_pem_x509_certificate(cert_pem.encode(), default_backend()) + der_bytes = cert.public_bytes(encoding=ssl.Encoding.DER) + digest = hashlib.sha256(der_bytes).digest() + certhash = base58.b58encode(digest).decode() + return f"uEi{certhash}" # js-libp2p compatible + + +def generate_webrtc_multiaddr( + ip: str, peer_id: str, certhash: Optional[str] = None +) -> Multiaddr: + add_protocol(Protocol(291, "webrtc-direct", "webrtc-direct")) + add_protocol(Protocol(292, "certhash", "certhash")) + # certhash = generate_local_certhash() + if not certhash: + raise ValueError("certhash must be provided for /webrtc-direct multiaddr") + + certificate= RTCCertificate.generateCertificate() + base = f"/ip4/{ip}/udp/9000/webrtc-direct/certhash/{certhash}/p2p/{peer_id}" + + return Multiaddr(base) + + +def filter_addresses(addrs: List[Multiaddr]) -> List[Multiaddr]: + """ + Filters the given list of multiaddresses, + returning only those that are valid for WebRTC transport. + + A valid WebRTC multiaddress typically contains /webrtc/ or /webrtc-direct/. + """ + valid_protocols = {"webrtc", "webrtc-direct"} + + def is_valid_webrtc_addr(addr: Multiaddr) -> bool: + try: + protocols = [proto.name for proto in addr.protocols()] + return any(p in valid_protocols for p in protocols) + except Exception: + return False + + return [addr for addr in addrs if is_valid_webrtc_addr(addr)] diff --git a/libp2p/transport/webrtc/listener.py b/libp2p/transport/webrtc/listener.py index bbe749dc7..881f6ad7f 100644 --- a/libp2p/transport/webrtc/listener.py +++ b/libp2p/transport/webrtc/listener.py @@ -2,8 +2,8 @@ import logging from typing import ( Any, - Tuple, ) + from aiortc import ( RTCDataChannel, RTCPeerConnection, @@ -12,43 +12,55 @@ from multiaddr import ( Multiaddr, ) +import trio from trio import ( - Nursery, Event, MemoryReceiveChannel, - MemorySendChannel + MemorySendChannel, ) -import trio from libp2p.abc import ( IListener, - TProtocol + THandler, + TProtocol, ) from libp2p.host.basic_host import ( BasicHost, ) -from libp2p.peer.id import ( - ID, -) + from .connection import ( - WebRTCRawConnection -) + WebRTCRawConnection, +) +from .gen_certhash import ( + CertificateManager, +) logger = logging.getLogger("webrtc") logging.basicConfig(level=logging.INFO) -SIGNAL_PROTOCOL: TProtocol = TProtocol("/libp2p/webrtc/signal/1.0.0") +SIGNAL_PROTOCOL: TProtocol = TProtocol("/libp2p/webrtc/signal/1.0.0") class WebRTCListener(IListener): - def __init__(self, host: BasicHost, peer_id: ID): - self.host = host - self.peer_id = peer_id + def __init__(self, handler: THandler): + self._handle_stream = handler + # self.peer_id = peer_id , peer_id: ID + self.host: BasicHost = None self.conn_send_channel: MemorySendChannel[WebRTCRawConnection] self.conn_receive_channel: MemoryReceiveChannel[WebRTCRawConnection] - self.conn_send_channel, self.conn_receive_channel = trio.open_memory_channel(0) + self.conn_send_channel, self.conn_receive_channel = trio.open_memory_channel(5) + self.certificate = str + + def set_host(self, host: BasicHost) -> None: + self.host = host - async def listen(self, maddr: Multiaddr, nursery: Nursery) -> bool: - self.host.set_stream_handler(SIGNAL_PROTOCOL, lambda stream: nursery.start_soon(self._handle_stream, stream)) + async def listen(self, maddr: Multiaddr) -> bool: + if not self.host: + raise RuntimeError("Host is not initialized in WebRTCListener") + + self.host.set_stream_handler( + SIGNAL_PROTOCOL, + self._handle_stream_wrapper, + ) await self.host.get_network().listen(maddr) return True @@ -59,19 +71,33 @@ async def close(self) -> None: await self.conn_send_channel.aclose() await self.conn_receive_channel.aclose() - async def _handle_stream(self, stream: Any) -> None: + async def _handle_stream_wrapper(self, stream: trio.SocketStream) -> None: + try: + await self._handle_stream_logic(stream) + except Exception as e: + logger.exception(f"Error in stream handler: {e}") + finally: + await stream.aclose() + + async def _handle_stream_logic(self, stream: trio.SocketStream) -> None: pc = RTCPeerConnection() channel_ready = Event() @pc.on("datachannel") def on_datachannel(channel: RTCDataChannel) -> None: + logger.info(f"DataChannel received: {channel.label}") + @channel.on("open") - def opened() -> None: + def on_open() -> None: + logger.info("DataChannel opened.") channel_ready.set() - self.conn_send_channel.send_nowait(WebRTCRawConnection(self.peer_id, channel)) + + self.conn_send_channel.send_nowait( + WebRTCRawConnection(self.host.get_id(), channel) + ) @pc.on("icecandidate") - def on_ice_candidate(candidate: Any) -> None: + async def on_ice_candidate(candidate: Any) -> None: if candidate: msg = { "type": "ice", @@ -83,9 +109,12 @@ def on_ice_candidate(candidate: Any) -> None: "port": candidate.port, "protocol": candidate.protocol, } - trio.lowlevel.spawn_system_task(stream.write, json.dumps(msg).encode()) + try: + await stream.send_all(json.dumps(msg).encode()) + except Exception as e: + logger.warning(f"Failed to send ICE candidate: {e}") - offer_data = await stream.read() + offer_data = await stream.receive_some(4096) offer_msg = json.loads(offer_data.decode()) offer = RTCSessionDescription(**offer_msg) await pc.setRemoteDescription(offer) @@ -93,11 +122,25 @@ def on_ice_candidate(candidate: Any) -> None: answer = await pc.createAnswer() await pc.setLocalDescription(answer) - await stream.write(json.dumps({"sdp": answer.sdp, "type": answer.type}).encode()) + await stream.send_all( + json.dumps( + {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type} + ).encode() + ) + await channel_ready.wait() - await stream.close() + await pc.close() - def get_addrs(self) -> Tuple[Multiaddr, ...]: - return ( - Multiaddr(f"/ip4/127.0.0.1/tcp/4001/ws/p2p/{self.peer_id}/p2p-circuit/webrtc"), - ) + def get_addrs(self) -> list[Multiaddr]: + peer_id = self.host.get_id() + certhash = CertificateManager()._compute_certhash(self.certificate.x509) + + base = "/ip4/127.0.0.1/tcp/0" + maddr_str = f"{base}/webrtc-direct/certhash/{certhash}/p2p/{peer_id}" + + try: + maddr = Multiaddr(maddr_str) + return [maddr] + except Exception as e: + logger.error(f"[WebRTCTransport] Failed to create listen Multiaddr: {e}") + return [] diff --git a/libp2p/transport/webrtc/signal_service.py b/libp2p/transport/webrtc/signal_service.py new file mode 100644 index 000000000..c9a0db675 --- /dev/null +++ b/libp2p/transport/webrtc/signal_service.py @@ -0,0 +1,83 @@ +import json +from typing import Callable, Awaitable, Dict, Optional +from multiaddr import Multiaddr +from libp2p.network.connection.raw_connection import RawConnection +from libp2p.peer.id import ID +from libp2p.abc import TProtocol, INotifee, INetStream, IHost +from aiortc import ( + RTCIceCandidate, +) +SIGNAL_PROTOCOL: TProtocol = TProtocol("/libp2p/webrtc/signal/1.0.0") + +class SignalService(INotifee): + def __init__(self, host: IHost): + self.host = host + self.signal_protocol = SIGNAL_PROTOCOL + self._handlers: dict[str, Callable[[dict, str], Awaitable[None]]] = {} + + def set_handler(self, msg_type: str, handler: Callable[[dict, str], Awaitable[None]]): + self._handlers[msg_type] = handler + + async def listen(self): + await self.host.set_stream_handler(self.signal_protocol, self.handle_signal) + + async def handle_signal(self, stream: INetStream) -> None: + peer_id = stream.muxed_conn.peer_id + reader = stream + + while True: + try: + data = await reader.read(4096) + if not data: + break + msg = json.loads(data.decode()) + msg_type = msg.get("type") + if msg_type in self._handlers: + await self._handlers[msg_type](msg, str(peer_id)) + else: + print(f"No handler for msg type: {msg_type}") + except Exception as e: + print(f"Error in signal handler for {peer_id}: {e}") + break + + async def send_signal(self, peer_id: ID, message: Dict): + try: + stream = await self.host.new_stream(peer_id, [self.signal_protocol]) + await stream.write(json.dumps(message).encode()) + await stream.close() + except Exception as e: + print(f"Failed to send signal to {peer_id}: {e}") + + async def send_offer(self, peer_id: ID, sdp: str, sdp_type: str, certhash: str): + await self.send_signal(peer_id, {"type": "offer", "sdp": sdp, "sdpType": sdp_type, "certhash": certhash}) + + async def send_answer(self, peer_id: ID, sdp: str, sdp_type: str, certhash: str): + await self.send_signal(peer_id, {"type": "answer", "sdp": sdp, "sdpType": sdp_type, "certhash": certhash}) + + async def send_ice_candidate(self, peer_id: ID, candidate: RTCIceCandidate): + await self.send_signal(peer_id, { + "type": "ice", + "candidateType": candidate.type, + "component": candidate.component, + "foundation": candidate.foundation, + "priority": candidate.priority, + "ip": candidate.ip, + "port": candidate.port, + "protocol": candidate.protocol, + "sdpMid": candidate.sdpMid, + }) + + async def connected(self, network, conn): + pass + + async def disconnected(self, network, conn): + pass + + async def opened_stream(self, network, stream): + pass + + async def closed_stream(self, network, stream): + pass + + async def listen_close(self, network, multiaddr): + pass diff --git a/libp2p/transport/webrtc/test_loopback.py b/libp2p/transport/webrtc/test_loopback.py index e704e29c4..b96c421c5 100644 --- a/libp2p/transport/webrtc/test_loopback.py +++ b/libp2p/transport/webrtc/test_loopback.py @@ -13,38 +13,36 @@ Nursery, ) +from libp2p import ( + new_host, +) from libp2p.crypto.ed25519 import ( create_new_key_pair, ) from libp2p.host.basic_host import ( - BasicHost, -) -from libp2p.network.swarm import ( - Swarm, + IHost, ) from libp2p.peer.id import ( ID, ) -from libp2p.peer.peerstore import ( - PeerStore, +from libp2p.peer.peerinfo import ( + info_from_p2p_addr, ) -from libp2p.security.noise.transport import ( - PROTOCOL_ID, +from libp2p.peer.peerstore import ( + PeerInfo, ) -from libp2p.security.noise.transport import Transport as NoiseTransport -from libp2p.stream_muxer.mplex.mplex import ( - MPLEX_PROTOCOL_ID, - Mplex, +from libp2p.pubsub.gossipsub import ( + GossipSub, ) -from libp2p.transport.upgrader import ( - TransportUpgrader, +from libp2p.pubsub.pubsub import ( + Pubsub, ) from .connection import ( WebRTCRawConnection, ) -from .listener import ( - WebRTCListener, +from .gen_certhash import ( + filter_addresses, ) from .webrtc import ( SIGNAL_PROTOCOL, @@ -55,21 +53,20 @@ logger = logging.getLogger("webrtc-loopback-test") -async def build_host(name: str) -> tuple[BasicHost, ID, WebRTCTransport]: +async def build_host(name: str) -> tuple[IHost, ID, WebRTCTransport]: key_pair = create_new_key_pair() peer_id = ID.from_pubkey(key_pair.public_key) + key_pair.private_key + key_pair.public_key logger.info(f"[{name}] Peer ID: {peer_id}") - webrtc_transport = WebRTCTransport(peer_id=peer_id, host=None) - peer_store = PeerStore() - - secure_transports = {PROTOCOL_ID: NoiseTransport(libp2p_keypair=key_pair)} - muxer_transports = {MPLEX_PROTOCOL_ID: Mplex} - upgrader = TransportUpgrader(secure_transports, muxer_transports) - - swarm = Swarm(peer_id, peer_store, upgrader, webrtc_transport) - host = BasicHost(swarm) - webrtc_transport.host = host + host = new_host() + pubsub = Pubsub( + host, + GossipSub(protocols=[SIGNAL_PROTOCOL], degree=10, degree_low=3, degree_high=15), + None, + ) + webrtc_transport = WebRTCTransport(host, pubsub) return host, peer_id, webrtc_transport @@ -77,32 +74,35 @@ async def build_host(name: str) -> tuple[BasicHost, ID, WebRTCTransport]: async def run_loopback_test(nursery: Nursery) -> None: host_b, peer_id_b, webrtc_transport_b = await build_host("Server") logger.info(f"[B] Peer ID: {peer_id_b}") - logger.info(f"[B] Listening Addrs: {host_b.get_connected_peers()}") + addrs = host_b.get_addrs() + PeerInfo(peer_id_b, addrs) + listen = host_b.run(addrs) + logger.info(f"[B] Listening Addrs: {addrs} --- {listen}") + logger.info(f"[B] Active Addrs: {host_b.get_live_peers()}") webrtc_proto = Protocol(name="webrtc", code=277, codec=None) add_protocol(webrtc_proto) - webrtc_conn = webrtc_transport_b.create_listener(WebRTCListener) + # webrtc_conn = await webrtc_transport_b.create_listener(handler_func=) - # await webrtc_listener.listen( + # await webrtc_conn.listen( # Multiaddr(f"/ip4/127.0.0.1/tcp/9095/ws/p2p/{peer_id_b}/p2p-circuit/webrtc"), # nursery, # ) - # for addr in webrtc_listener.: - # logger.info(f"[B] Listening on: {addr}") - logger.info("[B] Listening WebRTC setup complete.") async def act_as_server() -> None: try: logger.info("[B] Waiting to accept connection...") - conn: WebRTCRawConnection = await webrtc_conn - logger.info("[B] Connection accepted.") + active_maddr = host_b.get_addrs() + info = info_from_p2p_addr(active_maddr[0]) + conn = host_b.connect(info) + logger.info(f"[B] Connection accepted. {conn}") - stream = await host_b.new_stream(conn.peer_id, [SIGNAL_PROTOCOL]) + stream = await host_b.new_stream(peer_id_b, [SIGNAL_PROTOCOL]) offer_data = await stream.read() offer_json = json.loads(offer_data.decode()) - await webrtc_transport_b.handle_offer_from_peer(stream, offer_json) + await webrtc_transport_b._handle_signal_message(peer_id_b, offer_json) answer = await webrtc_transport_b.peer_connection.createAnswer() await webrtc_transport_b.peer_connection.setLocalDescription(answer) @@ -127,15 +127,21 @@ async def act_as_server() -> None: async def act_as_client() -> None: host_a, peer_id_a, webrtc_client = await build_host("Client") - await webrtc_client.create_data_channel() - - maddr = Multiaddr( - f"/ip4/127.0.0.1/tcp/4001/ws/p2p/{peer_id_b}/p2p-circuit/webrtc" - ) - host_a.get_network().peerstore.add_addr(peer_id_b, maddr, 3000) + pc = webrtc_client._create_peer_connection() + await webrtc_client.create_data_channel(pc) + + valid = [ + Multiaddr(f"/ip4/127.0.0.1/udp/9095/webrtc/p2p/{peer_id_a}"), + Multiaddr(f"/ip4/127.0.0.1/tcp/4001/ws/p2p/{peer_id_a}/p2p-circuit/webrtc"), + Multiaddr(f"/ip4/127.0.0.1/tcp/4001/ws/p2p/{peer_id_b}/p2p-circuit/webrtc"), + Multiaddr(f"/ip4/127.0.0.1/udp/9095/webrtc/p2p/{peer_id_b}"), + ] + + maddr = filter_addresses(valid) + host_a.get_network().peerstore.add_addr(peer_id_a, maddr, 3000) logger.info(f"[A] Peerstore updated with address: {maddr}") - stream = await host_a.new_stream(peer_id_b, [SIGNAL_PROTOCOL]) + stream = await host_a.new_stream(peer_id_a, [SIGNAL_PROTOCOL]) offer = await webrtc_client.peer_connection.createOffer() await webrtc_client.peer_connection.setLocalDescription(offer) @@ -150,7 +156,7 @@ async def act_as_client() -> None: await trio.sleep(1) if webrtc_client.data_channel is not None: - conn = WebRTCRawConnection(peer_id_b, webrtc_client.data_channel) + conn = WebRTCRawConnection(peer_id_a, webrtc_client.data_channel) await conn.write(b"Hello from A") reply = await conn.read() logger.info(f"[A] Received: {reply.decode()}") diff --git a/libp2p/transport/webrtc/test_webrtc.py b/libp2p/transport/webrtc/test_webrtc.py new file mode 100644 index 000000000..7cb79fa9b --- /dev/null +++ b/libp2p/transport/webrtc/test_webrtc.py @@ -0,0 +1,119 @@ +import trio +import pytest +from libp2p.tools.constants import LISTEN_MADDR +from .webrtc import WebRTCTransport +from libp2p import new_host +from libp2p.pubsub.gossipsub import GossipSub +from libp2p.pubsub.pubsub import Pubsub +from multiaddr import Multiaddr +from libp2p.abc import TProtocol +from multiaddr.protocols import ( + Protocol, + add_protocol, +) +from .connection import WebRTCRawConnection +import logging +from libp2p.crypto.ed25519 import ( + create_new_key_pair, +) +from libp2p.host.basic_host import ( + IHost, +) +from libp2p.peer.id import ( + ID, +) +from .gen_certhash import CertificateManager, generate_webrtc_multiaddr + +logger = logging.getLogger("test-webrtc") +logging.basicConfig(level=logging.INFO) +SIGNAL_PROTOCOL: TProtocol = TProtocol("/libp2p/webrtc/signal/1.0.0") + +@pytest.mark.trio +async def test_webrtc_transport_end_to_end() -> None: + async with trio.open_nursery() as nursery: + Key_pair = create_new_key_pair() + peer_id = ID.from_pubkey(Key_pair.public_key) + print(f"Peer ID: {peer_id}") + # Create Peer A + host_a = new_host(key_pair=Key_pair) + pubsub_a = Pubsub( + host_a, + GossipSub(protocols=[SIGNAL_PROTOCOL], degree=10, degree_low=3, degree_high=15), + None, + ) + transport_a = WebRTCTransport(host_a, pubsub_a) + await transport_a.start() + + # Create Peer B + host_b = new_host() + pubsub_b = Pubsub( + host_b, + GossipSub(protocols=[SIGNAL_PROTOCOL], degree=10, degree_low=3, degree_high=15), + None, + ) + transport_b = WebRTCTransport(host_b, pubsub_b) + await transport_b.start() + + list_addr= host_b.get_addrs() + peer_id = host_b.get_id() + # if "webrtc" not in Protocol.name : + add_protocol(Protocol(name="webrtc", code=288, codec=None)) + add_protocol(Protocol(name="webrtc-direct", code= 289, codec= None)) + add_protocol(Protocol(name="certhash", code= 292, codec= None)) + add_protocol(Protocol(name="uEiBRYnd2NEs_2ycHGcld_M94-cKNOoLZamuSaGz_ArLaUA", code= 293, codec= None)) + maddr_b = Multiaddr(f"/ip4/0.0.0.0/tcp/0/ws/p2p/{host_b.get_id()}/p2p-circuit/webrtc") + + list_multiaddr= [ + Multiaddr(f"/ip4/127.0.0.1/udp/9095/webrtc/p2p/{peer_id}"), + Multiaddr(f"/ip4/127.0.0.1/tcp/4001/ws/p2p/{peer_id}/p2p-circuit/webrtc"), + Multiaddr(f"/ip4/127.0.0.1/tcp/4001/ws/p2p/{peer_id}/p2p-circuit/webrtc"), + Multiaddr(f"/ip4/127.0.0.1/udp/9095/webrtc/p2p/{peer_id}"), + ] + + # nursery.start_soon(host_a.run, []) + # nursery.start_soon(host_b.run, []) + host_a.run(listen_addrs=LISTEN_MADDR) + host_b.run(listen_addrs=LISTEN_MADDR) + # await trio.sleep(2) + + certhash = CertificateManager().get_certhash() + print(f"Certificate PEM: {certhash}") + signal_webrtc_maddr = generate_webrtc_multiaddr("192.168.0.1", str(peer_id), certhash="uEiBRYnd2NEs_2ycHGcld_M94-cKNOoLZamuSaGz_ArLaUA") + signal_maddr = "/ip4/0.0.0.0/tcp/ws/0" + print(f"Signal WebRTC Multiaddr: {signal_maddr}") + print(f"Signal WebRTC Multiaddr: {signal_webrtc_maddr}") + + # conn = await transport_a.dial(maddr_b) + conn: WebRTCRawConnection = await transport_a.dial(signal_maddr) + + # Check the channel is open + assert conn is not None + assert conn.channel.readyState == "open" + + logger.info("[Test] WebRTC channel open. Sending message...") + + # Send message + test_msg = "Hello from A" + conn.channel.send(test_msg) + logger.info(f"[Test] Peer A sent: {test_msg}") + + # Listen on B + received = trio.Event() + def on_message(msg: str) -> None: + if msg == test_msg: + logger.info(f"[Test] Peer B received: {msg}") + received.set() + + for ch in transport_b.connected_peers.values(): + if ch.readyState == "open": + ch.on("message", on_message) + + conn.channel.send(test_msg) + + trio.fail_after(5) + + logger.info("[Test] Message exchange succeeded") + await host_a.close() + await host_b.close() + nursery.cancel_scope.cancel() + diff --git a/libp2p/transport/webrtc/utils.py b/libp2p/transport/webrtc/utils.py deleted file mode 100644 index 3599964ce..000000000 --- a/libp2p/transport/webrtc/utils.py +++ /dev/null @@ -1,52 +0,0 @@ -from multiaddr import ( - Multiaddr, -) -import logging - -logger = logging.getLogger("webrtc") -logging.basicConfig(level=logging.INFO) - -def parse_webrtc_multiaddr(multiaddr_str: str) -> Multiaddr: - """ - Parse, validate, and extract components from a WebRTC multiaddr. - - Expected format: - /ip4|dns4|dns6/
/tcp//p2p//p2p-circuit/webrtc - """ - try: - addr = Multiaddr(multiaddr_str) - protocols = [p.name for p in addr.protocols()] - - if "webrtc" not in protocols: - raise ValueError("Missing /webrtc protocol in multiaddr") - - if "p2p" not in protocols: - raise ValueError("Missing /p2p protocol (peer ID required)") - - # Extracting peer ID and address components - components = addr.items() - peer_id = None - ip_or_dns = None - port = None - - for proto, value in components: - if proto in ("ip4", "ip6", "dns4", "dns6"): - ip_or_dns = value - elif proto == "tcp": - port = value - elif proto == "p2p": - peer_id = value - - if not all([ip_or_dns, port, peer_id]): - raise ValueError("Incomplete multiaddr: Must include IP/DNS, TCP port, and Peer ID") - - return { - "multiaddr": addr, - "peer_id": peer_id, - "ip_or_dns": ip_or_dns, - "port": port - } - - except Exception as e: - logger.error(f"[parse_webrtc_multiaddr] Failed to parse multiaddr: {e}") - return None diff --git a/libp2p/transport/webrtc/webrtc.py b/libp2p/transport/webrtc/webrtc.py index 83154371c..e02fad504 100644 --- a/libp2p/transport/webrtc/webrtc.py +++ b/libp2p/transport/webrtc/webrtc.py @@ -2,32 +2,31 @@ import logging from typing import ( Any, - Callable, Optional, ) -from _collections_abc import ( - Coroutine, -) from aiortc import ( RTCDataChannel, RTCIceCandidate, RTCPeerConnection, RTCSessionDescription, ) +import anyio.from_thread from multiaddr import ( Multiaddr, ) import trio from libp2p.abc import ( - IListener, - ISecureConn, ITransport, + THandler, TProtocol, ) +from libp2p.crypto.ed25519 import ( + create_new_key_pair, +) from libp2p.host.basic_host import ( - BasicHost, + IHost, ) from libp2p.peer.id import ( ID, @@ -42,13 +41,19 @@ from .connection import ( WebRTCRawConnection, ) +from .gen_certhash import ( + CertificateManager, + parse_webrtc_maddr, +) from .listener import ( WebRTCListener, ) -from .utils import ( - parse_webrtc_multiaddr, +from .signal_service import ( + SignalService, ) +# from upgrader import TransportUpgrader + logger = logging.getLogger("webrtc") logging.basicConfig(level=logging.INFO) SIGNAL_PROTOCOL: TProtocol = TProtocol("/libp2p/webrtc/signal/1.0.0") @@ -56,41 +61,179 @@ class WebRTCTransport(ITransport): def __init__( - self, peer_id: ID, host: BasicHost, config: Optional[dict[str, Any]] = None + self, host: IHost, pubsub: Pubsub, config: Optional[dict[str, Any]] = None ): - self.peer_id = peer_id self.host = host + key_pair = create_new_key_pair() + self.peer_id = ID.from_pubkey(key_pair.public_key) self.config = config or {} - self.peer_connection = RTCPeerConnection() + self.certificate = CertificateManager().certificate self.data_channel: Optional[RTCDataChannel] = None - self.pubsub: Optional[Pubsub] = None + self.connected_peers: dict[str, RTCDataChannel] = {} + self.pubsub = pubsub + self._listeners: list[Multiaddr] = [] + self.peer_connection: RTCPeerConnection + config = {"iceServers": [...], "upgrader": Any} + self.ice_servers = self.config.get( + "iceServers", + [ + {"urls": "stun:stun.l.google.com:19302"}, + {"urls": "stun:stun1.l.google.com:19302"}, + ], + ) + + self.signal_service = SignalService(self.host) + self.upgrader = self.config.get("upgrader") + + def _create_peer_connection(self) -> RTCPeerConnection: + return RTCPeerConnection( + configuration={"iceServers": self.ice_servers}, + certificates=[self.certificate], + ) + + async def start(self) -> None: + await self.start_peer_discovery() - async def create_data_channel(self) -> None: - self.data_channel = self.peer_connection.createDataChannel("libp2p-webrtc") + async with trio.open_nursery() as nursery: + nursery.start_soon(self.handle_offer) + logger.info("[WebRTC] WebRTCTransport started and listening for direct offers") + + async def start_peer_discovery(self) -> None: + if not self.pubsub: + gossipsub = GossipSub( + protocols=[SIGNAL_PROTOCOL], degree=10, degree_low=3, degree_high=15 + ) + self.pubsub = Pubsub(self.host, gossipsub, None) + + topic = await self.pubsub.subscribe("webrtc-peer-discovery") + + async def handle_message() -> None: + async for msg in topic: + logger.info(f"Discovered Peer: {msg.data.decode()}") + + async with trio.open_nursery() as nursery: + nursery.start_soon(handle_message) + nursery.start_soon(self.handle_offer) + logger.info( + "[WebRTC] WebRTCTransport started and listening for direct offers" + ) + + await self.pubsub.publish( + "webrtc-peer-discovery", str(self.peer_id).encode() + ) + + def verify_peer_certificate(self, remote_cert, expected_certhash: str): + """ + Compute the certhash of the remote certificate and compare to expected. + """ + cert_mgr = CertificateManager() + actual_certhash = cert_mgr._compute_certhash(remote_cert) + if actual_certhash != expected_certhash: + raise ValueError( + f"Certhash: expected {expected_certhash}, got {actual_certhash}" + ) + + def verify_peer_id(self, remote_peer_id: str, expected_peer_id: str): + if remote_peer_id != expected_peer_id: + raise ValueError( + f"Peer ID mismatch: expected {expected_peer_id}, got {remote_peer_id}" + ) + + async def create_data_channel( + self, pc: RTCPeerConnection, label: str = "libp2p-webrtc" + ) -> RTCDataChannel: + channel = pc.createDataChannel(label) - @self.data_channel.on("open") + @channel.on("open") def on_open() -> None: - logger.info(f"Data channel open with peer {self.peer_id}") + logger.info("[WebRTC] Data channel open with peer") - @self.data_channel.on("message") + @channel.on("message") def on_message(message: Any) -> None: - logger.info(f"Received message from peer {self.peer_id}: {message}") + logger.info(f"[WebRTC] Message received: {message}") + + return channel + + def relay_message(self, message: Any, exclude_peer: Optional[str] = None) -> None: + """ + Relay incoming message to all other connected peers, excluding the sender. + """ + for pid, channel in list(self.connected_peers.items()): + if pid == exclude_peer: + continue + if channel.readyState != "open": + logger.warning(f"[Relay] Channel to {pid} not open. Removing.") + self.connected_peers.pop(pid, None) + continue + try: + channel.send(message) + logger.info(f"[Relay] Forwarded message to {pid}") + except Exception as e: + logger.exception(f"[Relay] Error sending to {pid}: {e}") + self.connected_peers.pop(pid, None) + + async def _handle_signal_message(self, peer_id: str, data: dict[str, Any]): + self.host.get_id() + msg_type = data.get("type") + if msg_type == "offer": + await self._handle_signal_offer(peer_id, data) + elif msg_type == "answer": + await self._handle_signal_answer(peer_id, data) + elif msg_type == "ice": + await self._handle_signal_ice(peer_id, data) + + async def _handle_signal_offer(self, peer_id: str, data: dict[str, Any]): + pc = self._create_peer_connection() + self.peer_connection = pc - async def handle_offer_from_peer(self, stream: Any, data: dict[str, Any]) -> None: - offer = RTCSessionDescription(sdp=data["sdp"], type=data["sdpType"]) - await self.peer_connection.setRemoteDescription(offer) + channel_ready = trio.Event() + + @pc.on("datachannel") + def on_datachannel(channel): + self.connected_peers[peer_id] = channel - answer = await self.peer_connection.createAnswer() - await self.peer_connection.setLocalDescription(answer) + @channel.on("open") + def on_open(): + channel_ready.set() - response: dict[str, Any] = { - "type": "answer", - "sdp": answer.sdp, - "sdpType": answer.type, - "peer_id": str(self.peer_id), - } + @channel.on("message") + def on_message(msg): + self.relay_message(msg, exclude_peer=peer_id) - await stream.write(json.dumps(response).encode()) + offer = RTCSessionDescription(sdp=data["sdp"], type=data["sdpType"]) + await pc.setRemoteDescription(offer) + + answer = await pc.createAnswer() + await pc.setLocalDescription(answer) + + await self.signal_service.send_answer( + peer_id, + { + "sdp": pc.localDescription.sdp, + "sdpType": pc.localDescription.type, + "certhash": CertificateManager()._compute_certhash( + self.certificate.x509 + ), + }, + ) + await channel_ready.wait() + + async def _handle_signal_answer(self, peer_id: str, data: dict[str, Any]): + answer = RTCSessionDescription(sdp=data["sdp"], type=data["sdpType"]) + await self.peer_connection.setRemoteDescription(answer) + + async def _handle_signal_ice(self, peer_id: str, data: dict[str, Any]): + candidate = RTCIceCandidate( + component=data["component"], + foundation=data["foundation"], + priority=data["priority"], + ip=data["ip"], + protocol=data["protocol"], + port=data["port"], + type=data["candidateType"], + sdpMid=data["sdpMid"], + ) + await self.peer_connection.addIceCandidate(candidate) async def handle_answer_from_peer(self, data: dict[str, Any]) -> None: answer = RTCSessionDescription(sdp=data["sdp"], type=data["sdpType"]) @@ -105,9 +248,28 @@ async def handle_ice_candidate(self, data: dict[str, Any]) -> None: protocol=data["protocol"], port=data["port"], type=data["candidateType"], + sdpMid=data["sdpMid"], ) await self.peer_connection.addIceCandidate(candidate) + async def create_listener(self, handler_func: THandler) -> WebRTCListener: + def on_new_stream(stream): + handler_func(stream) + + pc = self._create_peer_connection() + channel = await self.create_data_channel(pc, "webrtc-dial") + channel_ready = trio.Event() + + @channel.on("open") + def on_open(): + channel_ready.set() + + raw_conn = WebRTCRawConnection(self.peer_id, channel) + raw_conn.on_stream(on_new_stream) + if not self.host: + raise RuntimeError("Host not initialized") + return WebRTCListener(handler=handler_func) + async def handle_incoming_candidates( self, stream: Any, peer_connection: RTCPeerConnection ) -> None: @@ -124,32 +286,34 @@ async def handle_incoming_candidates( protocol=data["protocol"], port=data["port"], type=data["candidateType"], + sdpMid=data["sdpMid"], ) await peer_connection.addIceCandidate(candidate) except Exception as e: logger.error(f"[ICE Trickling] Error reading ICE candidate: {e}") + await stream.close() break - async def start_peer_discovery(self) -> None: - gossipsub = GossipSub(protocols=[], degree=10, degree_low=3, degree_high=15) - self.pubsub = Pubsub(self.host, gossipsub, None) - - topic = await self.pubsub.subscribe("webrtc-peer-discovery") - - async def handle_message() -> None: - async for msg in topic: - logger.info(f"Discovered Peer: {msg.data.decode()}") - - async with trio.open_nursery() as nursery: - nursery.start_soon(handle_message) + async def dial(self, maddr: Multiaddr) -> WebRTCRawConnection: + _, peer_id, certhash = parse_webrtc_maddr(maddr) + stream = await self.host.new_stream(peer_id, [SIGNAL_PROTOCOL]) - await self.pubsub.publish("webrtc-peer-discovery", str(self.peer_id).encode()) + pc = self._create_peer_connection() + channel = await self.create_data_channel(pc, "webrtc-dial") + channel_ready = trio.Event() + self.connected_peers[peer_id] = channel + # cert_pem= CertificateManager() + # cert: CertificateManager = cert_pem.generate_self_signed_cert() + # print(f"Certificate PEM: {cert}") - async def dial(self, maddr: Multiaddr) -> ISecureConn: - peer_id = parse_webrtc_multiaddr(maddr) - stream = await self.host.new_stream(peer_id, [SIGNAL_PROTOCOL]) + @channel.on("open") + def on_open(): + channel_ready.set() - pc = RTCPeerConnection() + @channel.on("message") + def on_message(msg): + logger.info(f"[Relay] Received from {peer_id}: {msg}") + self.relay_message(msg, exclude_peer=peer_id) @pc.on("icecandidate") def on_ice_candidate(candidate: Optional[RTCIceCandidate]) -> None: @@ -162,42 +326,213 @@ def on_ice_candidate(candidate: Optional[RTCIceCandidate]) -> None: "priority": candidate.priority, "port": candidate.port, "protocol": candidate.protocol, + "ip": candidate.ip, + "sdpMid": candidate.sdpMid, } trio.lowlevel.spawn_system_task(stream.write, json.dumps(msg).encode()) trio.lowlevel.spawn_system_task(self.handle_incoming_candidates, stream, pc) - channel = pc.createDataChannel("libp2p") - channel_ready = trio.Event() - - @channel.on("open") - def on_open() -> None: - channel_ready.set() - offer = await pc.createOffer() await pc.setLocalDescription(offer) + try: + # await self.signal_service.send_offer(peer_id, offer) + await self.signal_service.send_offer( + peer_id, + { + "sdp": pc.localDescription.sdp, + "sdpType": pc.localDescription.type, + "certhash": CertificateManager()._compute_certhash( + self.certificate.x509 + ), + }, + ) + except Exception as e: + logger.error(f"[Signaling] Failed to send offer to {peer_id}: {e}") + await stream.close() + raise + + await channel_ready.wait() + remote_cert = CertificateManager.getFingerprints() # remote cert for comparison + if not remote_cert: + raise ValueError("No remote certificate received") + remote_cert = remote_cert[0] + self.verify_peer_certificate(remote_cert, certhash) + self.verify_peer_id(peer_id, str(self.peer_id)) + await stream.write( json.dumps( { "type": "offer", - "peer_id": str(self.peer_id), + "peer_id": self.peer_id, "sdp": offer.sdp, "sdpType": offer.type, + "certhash": CertificateManager()._compute_certhash( + self.certificate.x509 + ), } ).encode() ) - answer_data = await stream.read() - answer_msg: dict[str, Any] = json.loads(answer_data.decode()) - answer = RTCSessionDescription(**answer_msg) - await pc.setRemoteDescription(answer) + try: + answer_data = await stream.read() + answer_msg: dict[str, Any] = json.loads(answer_data.decode()) + answer = RTCSessionDescription(**answer_msg) + await pc.setRemoteDescription(answer) + except Exception as e: + logger.error( + f"[Signaling] Failed to receive or process answer from {peer_id}: {e}" + ) + await stream.close() + raise await channel_ready.wait() - return WebRTCRawConnection(self.peer_id, channel) + raw_conn = WebRTCRawConnection(self.peer_id, channel) + logical_stream = await raw_conn.open_stream() + if self.upgrader: + upgraded_conn = await self.upgrader.upgrade_connection(logical_stream) + return upgraded_conn + else: + return logical_stream + + async def webrtc_direct_dial(self, maddr: Multiaddr) -> WebRTCRawConnection: + protocols = [p.name for p in maddr.protocols()] + if "webrtc-direct" in protocols: + logger.info("[Dial] Detected /webrtc-direct multiaddr....") + + ip, peer_id, certhash = parse_webrtc_maddr(maddr) + if not ip or not peer_id: + raise ValueError("Missing IP or Peer ID in webrtc-direct multiaddr") + logger.info( + f"Parsed IP={ip}, PeerID={peer_id}, Certhash={certhash or 'None'}" + ) + + pc = self._create_peer_connection() + channel = await self.create_data_channel( + pc, label="py-libp2p-webrtc-direct" + ) + channel_ready = trio.Event() + self.connected_peers[peer_id] = channel + + @channel.on("open") + def on_open() -> None: + logger.info(f"[webrtc-direct] Channel open with {peer_id}") + channel_ready.set() + + @channel.on("message") + def on_message(msg: Any) -> None: + logger.info(f"[Relay] Received from {peer_id}: {msg}") + self.relay_message(msg, exclude_peer=peer_id) + + offer = await anyio.from_thread.run_sync(pc.createOffer) + await anyio.from_thread.run_sync(pc.setLocalDescription, offer) + + if self.pubsub is None: + await self.start_peer_discovery() + await self.pubsub.publish( + f"webrtc-offer-{peer_id}", + json.dumps( + { + "peer_id": self.peer_id, + "sdp": offer.sdp, + "sdpType": offer.type, + "certhash": CertificateManager()._compute_certhash( + self.certificate.x509 + ), + } + ).encode(), + ) + + logger.info(f"[webrtc-direct] Sent offer to peer {self.peer_id} via pubsub") + + topic = await self.pubsub.subscribe(f"webrtc-answer-{self.peer_id}") + async for msg in topic: + answer_data = json.loads(msg.data.decode()) + answer = RTCSessionDescription(**answer_data) + # await pc.setRemoteDescription(answer) + await anyio.from_thread.run_sync(pc.setRemoteDescription, answer) + break + + await channel_ready.wait() + raw_conn = WebRTCRawConnection(peer_id, channel) + if self.upgrader: + upgraded_conn = await self.upgrader.upgrade_connection(raw_conn) + return upgraded_conn + else: + return raw_conn + else: + logger.info("[Dial] Falling back to regular signal-based WebRTC") + return await self.dial(maddr) + + async def handle_offer(self): + logger.info("[signal] Listening for incoming offers via SignalService") + # await self.signal_service.listen() + + async def _on_offer(msg): + try: + data = json.loads(msg.data.decode()) + remote_peer_id = data["peer_id"] + offer = RTCSessionDescription(sdp=data["sdp"], type=data["sdpType"]) + + pc = self._create_peer_connection() + logger.info( + f"[webrtc-direct] Received offer from peer {remote_peer_id}" + ) + channel_ready = trio.Event() + + @pc.on("datachannel") + def on_datachannel(channel): + logger.info( + f"[webrtc-direct] Datachannel received from {remote_peer_id}" + ) + self.connected_peers[remote_peer_id] = channel + + @channel.on("open") + def on_open(): + logger.info( + f"[webrtc-direct] Channel open with {remote_peer_id}" + ) + channel_ready.set() + + @channel.on("message") + def on_message(msg): + logger.info(f"[Relay] Received from {remote_peer_id}: {msg}") + self.relay_message(msg, exclude_peer=remote_peer_id) + + offer = RTCSessionDescription(sdp=data["sdp"], type=data["sdpType"]) + await pc.setRemoteDescription(offer) + remote_cert = self.peer_connection.getRemoteCertificates()[0] + expected_certhash = data.get("certhash") + self.verify_peer_certificate(remote_cert, expected_certhash) + self.verify_peer_id(remote_peer_id, str(self.peer_id)) + + answer = await pc.createAnswer() + await pc.setLocalDescription(answer) + + response_topic = f"webrtc-answer-{remote_peer_id}" + await self.pubsub.publish( + response_topic, + json.dumps( + { + "peer_id": str(self.peer_id), + "sdp": pc.localDescription.sdp, + "sdpType": pc.localDescription.type, + "certhash": CertificateManager()._compute_certhash( + self.certificate.x509 + ), + } + ).encode(), + ) + logger.info(f"ans sent to peer {remote_peer_id} via {response_topic}") + await channel_ready.wait() - async def create_listener( - self, handler: Callable[[WebRTCListener], Coroutine[Any, Any, None]] - ) -> IListener: - listener = await self.create_listener(handler=handler) - return listener + except Exception as e: + logger.error(f"[webrtc-direct] Error handling offer: {e}") + + offer_topic = f"webrtc-offer-{remote_peer_id}" + logger.info(f"[webrtc-direct] Subscribing to topic: {offer_topic}") + topic = await self.pubsub.subscribe(offer_topic) + + async for msg in topic: + await _on_offer(msg) From dfe6629a718bcf34bfdf9e6b3dd87f452f2f83d6 Mon Sep 17 00:00:00 2001 From: Neha Kumari Date: Tue, 3 Jun 2025 02:28:57 +0530 Subject: [PATCH 6/9] feat(webrtc): add maddr parsing, SDP munging& test coverage --- .pre-commit-config.yaml | 1 + docs/libp2p.transport.rst | 1 + docs/libp2p.transport.webrtc.rst | 69 +++ libp2p/transport/webrtc/connection.py | 4 +- libp2p/transport/webrtc/gen_certhash.py | 179 ++++++-- libp2p/transport/webrtc/listener.py | 100 ++-- libp2p/transport/webrtc/signal_service.py | 71 ++- .../transport/webrtc/test_gen_certificate.py | 50 ++ libp2p/transport/webrtc/test_listener.py | 90 ++++ libp2p/transport/webrtc/test_loopback.py | 179 -------- libp2p/transport/webrtc/test_signal.py | 95 ++++ libp2p/transport/webrtc/test_webrtc.py | 119 ----- .../webrtc/test_webrtc_direct_loopback.py | 119 +++++ libp2p/transport/webrtc/webrtc.py | 428 ++++++++++-------- 14 files changed, 914 insertions(+), 591 deletions(-) create mode 100644 docs/libp2p.transport.webrtc.rst create mode 100644 libp2p/transport/webrtc/test_gen_certificate.py create mode 100644 libp2p/transport/webrtc/test_listener.py delete mode 100644 libp2p/transport/webrtc/test_loopback.py create mode 100644 libp2p/transport/webrtc/test_signal.py delete mode 100644 libp2p/transport/webrtc/test_webrtc.py create mode 100644 libp2p/transport/webrtc/test_webrtc_direct_loopback.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1712b7f17..7d28fcc1d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -51,6 +51,7 @@ repos: language: system always_run: true pass_filenames: false + stages: [manual] - repo: local hooks: - id: check-rst-files diff --git a/docs/libp2p.transport.rst b/docs/libp2p.transport.rst index 0d92c48f5..9e96c7172 100644 --- a/docs/libp2p.transport.rst +++ b/docs/libp2p.transport.rst @@ -8,6 +8,7 @@ Subpackages :maxdepth: 4 libp2p.transport.tcp + libp2p.transport.webrtc Submodules ---------- diff --git a/docs/libp2p.transport.webrtc.rst b/docs/libp2p.transport.webrtc.rst new file mode 100644 index 000000000..1acbe085e --- /dev/null +++ b/docs/libp2p.transport.webrtc.rst @@ -0,0 +1,69 @@ +libp2p.transport.webrtc package +=============================== + +Submodules +---------- + +libp2p.transport.webrtc.connection module +----------------------------------------- + +.. automodule:: libp2p.transport.webrtc.connection + :members: + :show-inheritance: + :undoc-members: + +libp2p.transport.webrtc.gen\_certhash module +-------------------------------------------- + +.. automodule:: libp2p.transport.webrtc.gen_certhash + :members: + :show-inheritance: + :undoc-members: + +libp2p.transport.webrtc.listener module +--------------------------------------- + +.. automodule:: libp2p.transport.webrtc.listener + :members: + :show-inheritance: + :undoc-members: + +libp2p.transport.webrtc.signal\_service module +---------------------------------------------- + +.. automodule:: libp2p.transport.webrtc.signal_service + :members: + :show-inheritance: + :undoc-members: + +libp2p.transport.webrtc.test\_loopback module +--------------------------------------------- + +.. automodule:: libp2p.transport.webrtc.test_loopback + :members: + :show-inheritance: + :undoc-members: + +libp2p.transport.webrtc.test\_webrtc module +------------------------------------------- + +.. automodule:: libp2p.transport.webrtc.test_webrtc + :members: + :show-inheritance: + :undoc-members: + +libp2p.transport.webrtc.webrtc module +------------------------------------- + +.. automodule:: libp2p.transport.webrtc.webrtc + :members: + :show-inheritance: + :undoc-members: + +Module contents +--------------- + +.. automodule:: libp2p.transport.webrtc + :members: + :show-inheritance: + :undoc-members: \ No newline at end of file diff --git a/libp2p/transport/webrtc/connection.py b/libp2p/transport/webrtc/connection.py index ec5e35e6b..c84a6a967 100644 --- a/libp2p/transport/webrtc/connection.py +++ b/libp2p/transport/webrtc/connection.py @@ -13,7 +13,7 @@ ) from libp2p.abc import ( - ISecureConn, + IRawConnection, ) from libp2p.peer.id import ( ID, @@ -26,7 +26,7 @@ logging.basicConfig(level=logging.INFO) -class WebRTCRawConnection(ISecureConn): +class WebRTCRawConnection(IRawConnection): def __init__(self, peer_id: ID, channel: RTCDataChannel): self.peer_id = peer_id self.channel = channel diff --git a/libp2p/transport/webrtc/gen_certhash.py b/libp2p/transport/webrtc/gen_certhash.py index eeb5a99a1..c111f4cbf 100644 --- a/libp2p/transport/webrtc/gen_certhash.py +++ b/libp2p/transport/webrtc/gen_certhash.py @@ -1,32 +1,44 @@ -import base58 +import base64 +import datetime import hashlib -import ssl from typing import ( Optional, - List, - Tuple, ) -from cryptography import x509 -from cryptography.hazmat.backends import default_backend -import hashlib -import base64 -from cryptography import x509 -from cryptography.hazmat.primitives import hashes, serialization -from cryptography.hazmat.primitives.asymmetric import rsa -from cryptography.x509.oid import NameOID -import datetime -from multiaddr import Multiaddr -from typing import Tuple -from aiortc import RTCCertificate -from multiaddr.protocols import ( - Protocol, - add_protocol, + +from aiortc import ( + RTCCertificate, +) +import base58 +from cryptography import ( + x509, +) +from cryptography.hazmat.backends import ( + default_backend, +) +from cryptography.hazmat.primitives import ( + hashes, + serialization, +) +from cryptography.hazmat.primitives.asymmetric import ( + rsa, +) +from cryptography.x509.oid import ( + NameOID, +) +from multiaddr import ( + Multiaddr, +) + +from libp2p.peer.id import ( + ID, ) SIGNAL_PROTOCOL = "/libp2p/webrtc/signal/1.0.0" + class CertificateManager(RTCCertificate): def __init__(self): + self.x509 = None self.private_key = None self.certificate = None self.certhash = None @@ -35,9 +47,9 @@ def generate_self_signed_cert(self, common_name: str = "py-libp2p") -> None: self.private_key = rsa.generate_private_key( public_exponent=65537, key_size=2048 ) - subject = issuer = x509.Name([ - x509.NameAttribute(NameOID.COMMON_NAME, common_name) - ]) + subject = issuer = x509.Name( + [x509.NameAttribute(NameOID.COMMON_NAME, common_name)] + ) self.certificate = ( x509.CertificateBuilder() .subject_name(subject) @@ -57,7 +69,8 @@ def _compute_certhash(self, cert: x509.Certificate) -> str: return base64.urlsafe_b64encode(sha256_hash).decode("utf-8").rstrip("=") def get_certhash(self) -> str: - return self.certhash + # return self.certhash + return f"uEi{self.certhash}" def get_certificate_pem(self) -> bytes: return self.certificate.public_bytes(serialization.Encoding.PEM) @@ -70,55 +83,123 @@ def get_private_key_pem(self) -> bytes: ) -def parse_webrtc_maddr(maddr: Multiaddr) -> Tuple[str, int, str]: +class SDPMunger: + """Handle SDP modification for direct connections""" + + @staticmethod + def munge_offer(sdp: str, ip: str, port: int) -> str: + """Modify SDP offer for direct connection""" + lines = sdp.split("\n") + munged = [] + + for line in lines: + if line.startswith("a=candidate"): + # Modify ICE candidate to use provided IP/port + parts = line.split() + parts[4] = ip + parts[5] = str(port) + line = " ".join(parts) + munged.append(line) + + return "\n".join(munged) + + @staticmethod + def munge_answer(sdp: str, ip: str, port: int) -> str: + """Modify SDP answer for direct connection""" + return SDPMunger.munge_offer(sdp, ip, port) + + +def create_webrtc_multiaddr( + ip: str, peer_id: ID, certhash: str, direct: bool = False +) -> Multiaddr: + """Create WebRTC multiaddr with proper format""" + # For direct connections + if direct: + return Multiaddr( + f"/ip4/{ip}/udp/0/webrtc-direct" f"/certhash/{certhash}" f"/p2p/{peer_id}" + ) + + # For signaled connections + return Multiaddr(f"/ip4/{ip}/webrtc" f"/certhash/{certhash}" f"/p2p/{peer_id}") + # return Multiaddr(f"/ip4/{ip}/webrtc/p2p/{peer_id}") + + +def verify_certhash(remote_cert: x509.Certificate, expected_hash: str) -> bool: + """Verify remote certificate hash matches expected""" + der_bytes = remote_cert.public_bytes(serialization.Encoding.DER) + conv_hash = base64.urlsafe_b64encode(hashlib.sha256(der_bytes).digest()) + actual_hash = f"uEi{conv_hash.decode('utf-8').rstrip('=')}" + return actual_hash == expected_hash + + +def create_webrtc_direct_multiaddr(ip: str, port: int, peer_id: ID) -> Multiaddr: + """Create a WebRTC-direct multiaddr""" + # Format: /ip4//udp//webrtc-direct/p2p/ + return Multiaddr(f"/ip4/{ip}/udp/{port}/webrtc-direct/p2p/{peer_id}") + + +def parse_webrtc_maddr(maddr: Multiaddr) -> tuple[str, ID, str]: """ Parse a WebRTC multiaddr like: - /ip4/127.0.0.1/udp/5000/webrtc/certhash//p2p/ - Returns (ip, port, certhash) + /ip4/147.28.186.157/udp/9095/webrtc-direct/certhash/uEiDFVmAomKdAbivdrcIKdXGyuij_ax8b8at0GY_MJXMlwg/p2p/12D3KooWFhXabKDwALpzqMbto94sB7rvmZ6M28hs9Y9xSopDKwQr/p2p-circuit + /ip6/2604:1380:4642:6600::3/tcp/9095/p2p/12D3KooWFhXabKDwALpzqMbto94sB7rvmZ6M28hs9Y9xSopDKwQr/p2p-circuit/webrtc + /ip4/147.28.186.157/udp/9095/webrtc-direct/certhash/uEiDFVmAomKdAbivdrcIKdXGyuij_ax8b8at0GY_MJXMlwg/p2p/12D3KooWFhXabKDwALpzqMbto94sB7rvmZ6M28hs9Y9xSopDKwQr/p2p-circuit/webrtc + /ip4/127.0.0.1/udp/9000/webrtc-direct/certhash/uEia...1jI/p2p/12D3KooW...6HEh + Returns (ip, peer_id, certhash) """ - addr = Multiaddr(maddr) - ip = None - port = None - certhash = None + try: + if isinstance(maddr, str): + maddr = Multiaddr(maddr) + + parts = maddr.to_string().split("/") + + # Get IP (after ip4 or ip6) + ip_idx = parts.index("ip4" if "ip4" in parts else "ip6") + 1 + ip = parts[ip_idx] - for c in addr.protocols(): - if c.name == "ip4" or c.name == "ip6": - ip = addr.value_for_protocol(c.name) - elif c.name == "udp": - port = int(addr.value_for_protocol("udp")) - elif c.name == "certhash": - certhash = addr.value_for_protocol("certhash") + # Get certhash (after certhash) + certhash_idx = parts.index("certhash") + 1 + certhash = parts[certhash_idx] - if not ip or not port or not certhash: - raise ValueError("Invalid WebRTC multiaddress") + # Get peer ID (after p2p) + peer_id_idx = parts.index("p2p") + 1 + peer_id = parts[peer_id_idx] - return ip, port, certhash + if not all([ip, peer_id, certhash]): + raise ValueError("Missing required components in multiaddr") + return ip, peer_id, certhash -def generate_local_certhash(cert_pem: str) -> str: + except Exception as e: + raise ValueError(f"Invalid WebRTC ma: {e}") + + +def generate_local_certhash(cert_pem: bytes) -> bytes: cert = x509.load_pem_x509_certificate(cert_pem.encode(), default_backend()) - der_bytes = cert.public_bytes(encoding=ssl.Encoding.DER) + der_bytes = cert.public_bytes(encoding=serialization.Encoding.DER) digest = hashlib.sha256(der_bytes).digest() certhash = base58.b58encode(digest).decode() + print(f"local_certhash= {certhash}") return f"uEi{certhash}" # js-libp2p compatible def generate_webrtc_multiaddr( ip: str, peer_id: str, certhash: Optional[str] = None ) -> Multiaddr: - add_protocol(Protocol(291, "webrtc-direct", "webrtc-direct")) - add_protocol(Protocol(292, "certhash", "certhash")) - # certhash = generate_local_certhash() if not certhash: raise ValueError("certhash must be provided for /webrtc-direct multiaddr") - - certificate= RTCCertificate.generateCertificate() + + cert_mgr = CertificateManager() + certhash = cert_mgr.get_certhash() if not certhash else certhash + if not isinstance(peer_id, ID): + peer_id = ID(peer_id) + base = f"/ip4/{ip}/udp/9000/webrtc-direct/certhash/{certhash}/p2p/{peer_id}" - + return Multiaddr(base) -def filter_addresses(addrs: List[Multiaddr]) -> List[Multiaddr]: +def filter_addresses(addrs: list[Multiaddr]) -> list[Multiaddr]: """ Filters the given list of multiaddresses, returning only those that are valid for WebRTC transport. diff --git a/libp2p/transport/webrtc/listener.py b/libp2p/transport/webrtc/listener.py index 881f6ad7f..67959bc6c 100644 --- a/libp2p/transport/webrtc/listener.py +++ b/libp2p/transport/webrtc/listener.py @@ -1,11 +1,13 @@ import json import logging from typing import ( - Any, + Optional, ) from aiortc import ( + RTCConfiguration, RTCDataChannel, + RTCIceCandidate, RTCPeerConnection, RTCSessionDescription, ) @@ -21,19 +23,25 @@ from libp2p.abc import ( IListener, - THandler, TProtocol, ) + +# from .webrtc import ( +# WebRTCTransport, +# ) +from libp2p.crypto.ed25519 import ( + create_new_key_pair, +) from libp2p.host.basic_host import ( BasicHost, ) +from libp2p.peer.id import ( + ID, +) from .connection import ( WebRTCRawConnection, ) -from .gen_certhash import ( - CertificateManager, -) logger = logging.getLogger("webrtc") logging.basicConfig(level=logging.INFO) @@ -41,19 +49,45 @@ class WebRTCListener(IListener): - def __init__(self, handler: THandler): - self._handle_stream = handler - # self.peer_id = peer_id , peer_id: ID + def __init__(self): self.host: BasicHost = None + key_pair = create_new_key_pair() + self.peer_id = ID.from_pubkey(key_pair.public_key) + # self.transport = WebRTCTransport() self.conn_send_channel: MemorySendChannel[WebRTCRawConnection] self.conn_receive_channel: MemoryReceiveChannel[WebRTCRawConnection] - self.conn_send_channel, self.conn_receive_channel = trio.open_memory_channel(5) + self.conn_send_channel, self.conn_receive_channel = trio.open_memory_channel(10) self.certificate = str + self._listen_addrs: list[ + Multiaddr + ] = [] # ['/ip4/127.0.0.1/tcp/4001', '/ip4/127.0.0.1/tcp/4034/ws/p2p-circuit'] def set_host(self, host: BasicHost) -> None: self.host = host - async def listen(self, maddr: Multiaddr) -> bool: + async def listen(self, maddr: Multiaddr) -> None: + """Listen for both direct and signaled connections""" + # print(f"Listening on {maddr} + {maddr.protocols}") + if "webrtc-direct" in maddr: + await self._listen_direct(maddr) + else: + await self.listen_signaled(maddr) + + async def _listen_direct(self, maddr: Multiaddr) -> None: + """Listen for direct WebRTC connections""" + pc = RTCPeerConnection(RTCConfiguration(iceServers=[])) + + @pc.on("datachannel") + def on_datachannel(channel): + conn = WebRTCRawConnection(self.peer_id, channel) + self.conn_send_channel.send_nowait(conn) + + @pc.on("connectionstatechange") + async def on_connectionstatechange(): + if pc.connectionState == "failed": + await pc.close() + + async def listen_signaled(self, maddr: Multiaddr) -> bool: if not self.host: raise RuntimeError("Host is not initialized in WebRTCListener") @@ -62,14 +96,41 @@ async def listen(self, maddr: Multiaddr) -> bool: self._handle_stream_wrapper, ) await self.host.get_network().listen(maddr) + if maddr not in self._listen_addrs: + self._listen_addrs.append(maddr) return True + def get_addrs(self) -> tuple[Multiaddr, ...]: + return tuple(self._listen_addrs) + async def accept(self) -> WebRTCRawConnection: return await self.conn_receive_channel.receive() + async def _accept_loop(self): + """Accept incoming connections""" + while self._listening: + try: + # Wait for incoming connections from the transport + await trio.sleep(0.1) # Prevent busy waiting + + # Check connection pool for new connections + for peer_id, channel in self.transport.connection_pool.channels.items(): + if ( + channel.readyState == "open" + and peer_id not in self._processed_connections + ): + self._processed_connections.add(peer_id) + raw_conn = WebRTCRawConnection(self.transport.peer_id, channel) + await self.accept_queue.put(raw_conn) + + except Exception as e: + logger.error(f"[Listener] Error in accept loop: {e}") + await trio.sleep(1.0) + async def close(self) -> None: await self.conn_send_channel.aclose() await self.conn_receive_channel.aclose() + logger.info("[Listener] Closed") async def _handle_stream_wrapper(self, stream: trio.SocketStream) -> None: try: @@ -97,7 +158,7 @@ def on_open() -> None: ) @pc.on("icecandidate") - async def on_ice_candidate(candidate: Any) -> None: + async def on_ice_candidate(candidate: Optional[RTCIceCandidate]) -> None: if candidate: msg = { "type": "ice", @@ -105,9 +166,10 @@ async def on_ice_candidate(candidate: Any) -> None: "component": candidate.component, "foundation": candidate.foundation, "priority": candidate.priority, - "ip": candidate.address, + "ip": candidate.ip, "port": candidate.port, "protocol": candidate.protocol, + "sdpMid": candidate.sdpMid, } try: await stream.send_all(json.dumps(msg).encode()) @@ -130,17 +192,3 @@ async def on_ice_candidate(candidate: Any) -> None: await channel_ready.wait() await pc.close() - - def get_addrs(self) -> list[Multiaddr]: - peer_id = self.host.get_id() - certhash = CertificateManager()._compute_certhash(self.certificate.x509) - - base = "/ip4/127.0.0.1/tcp/0" - maddr_str = f"{base}/webrtc-direct/certhash/{certhash}/p2p/{peer_id}" - - try: - maddr = Multiaddr(maddr_str) - return [maddr] - except Exception as e: - logger.error(f"[WebRTCTransport] Failed to create listen Multiaddr: {e}") - return [] diff --git a/libp2p/transport/webrtc/signal_service.py b/libp2p/transport/webrtc/signal_service.py index c9a0db675..9fa1d5264 100644 --- a/libp2p/transport/webrtc/signal_service.py +++ b/libp2p/transport/webrtc/signal_service.py @@ -1,26 +1,42 @@ +from collections.abc import ( + Awaitable, +) import json -from typing import Callable, Awaitable, Dict, Optional -from multiaddr import Multiaddr -from libp2p.network.connection.raw_connection import RawConnection -from libp2p.peer.id import ID -from libp2p.abc import TProtocol, INotifee, INetStream, IHost +from typing import ( + Callable, +) + from aiortc import ( RTCIceCandidate, -) +) + +from libp2p.abc import ( + IHost, + INetStream, + INotifee, + TProtocol, +) +from libp2p.peer.id import ( + ID, +) + SIGNAL_PROTOCOL: TProtocol = TProtocol("/libp2p/webrtc/signal/1.0.0") + class SignalService(INotifee): def __init__(self, host: IHost): self.host = host self.signal_protocol = SIGNAL_PROTOCOL self._handlers: dict[str, Callable[[dict, str], Awaitable[None]]] = {} - def set_handler(self, msg_type: str, handler: Callable[[dict, str], Awaitable[None]]): + def set_handler( + self, msg_type: str, handler: Callable[[dict, str], Awaitable[None]] + ): self._handlers[msg_type] = handler async def listen(self): - await self.host.set_stream_handler(self.signal_protocol, self.handle_signal) - + self.host.set_stream_handler(self.signal_protocol, self.handle_signal) + async def handle_signal(self, stream: INetStream) -> None: peer_id = stream.muxed_conn.peer_id reader = stream @@ -40,7 +56,7 @@ async def handle_signal(self, stream: INetStream) -> None: print(f"Error in signal handler for {peer_id}: {e}") break - async def send_signal(self, peer_id: ID, message: Dict): + async def send_signal(self, peer_id: ID, message: dict): try: stream = await self.host.new_stream(peer_id, [self.signal_protocol]) await stream.write(json.dumps(message).encode()) @@ -49,23 +65,32 @@ async def send_signal(self, peer_id: ID, message: Dict): print(f"Failed to send signal to {peer_id}: {e}") async def send_offer(self, peer_id: ID, sdp: str, sdp_type: str, certhash: str): - await self.send_signal(peer_id, {"type": "offer", "sdp": sdp, "sdpType": sdp_type, "certhash": certhash}) + await self.send_signal( + peer_id, + {"type": "offer", "sdp": sdp, "sdpType": sdp_type, "certhash": certhash}, + ) async def send_answer(self, peer_id: ID, sdp: str, sdp_type: str, certhash: str): - await self.send_signal(peer_id, {"type": "answer", "sdp": sdp, "sdpType": sdp_type, "certhash": certhash}) + await self.send_signal( + peer_id, + {"type": "answer", "sdp": sdp, "sdpType": sdp_type, "certhash": certhash}, + ) async def send_ice_candidate(self, peer_id: ID, candidate: RTCIceCandidate): - await self.send_signal(peer_id, { - "type": "ice", - "candidateType": candidate.type, - "component": candidate.component, - "foundation": candidate.foundation, - "priority": candidate.priority, - "ip": candidate.ip, - "port": candidate.port, - "protocol": candidate.protocol, - "sdpMid": candidate.sdpMid, - }) + await self.send_signal( + peer_id, + { + "type": "ice", + "candidateType": candidate.type, + "component": candidate.component, + "foundation": candidate.foundation, + "priority": candidate.priority, + "ip": candidate.ip, + "port": candidate.port, + "protocol": candidate.protocol, + "sdpMid": candidate.sdpMid, + }, + ) async def connected(self, network, conn): pass diff --git a/libp2p/transport/webrtc/test_gen_certificate.py b/libp2p/transport/webrtc/test_gen_certificate.py new file mode 100644 index 000000000..0ca52adc3 --- /dev/null +++ b/libp2p/transport/webrtc/test_gen_certificate.py @@ -0,0 +1,50 @@ +from .gen_certhash import ( + CertificateManager +) +import pytest +import trio +from cryptography.x509.oid import ( + NameOID, +) + + + +# Certificate generation with default common name +@pytest.mark.trio +async def test_generate_self_signed_cert_with_default_common_name(): + cert_manager = CertificateManager() + + cert_manager.generate_self_signed_cert() + + assert cert_manager.certificate is not None + assert cert_manager.private_key is not None + assert cert_manager.certhash is not None + + # Verify the common name is the default "py-libp2p" + cert_subject = cert_manager.certificate.subject + common_name = cert_subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value + assert common_name == "py-libp2p" + + # Verify certhash format (base64 URL-safe encoded string) + certhash = cert_manager.get_certhash() + assert isinstance(certhash, str) + assert "+" not in certhash + assert "/" not in certhash + + + +# Accessing certhash before certificate generation +@pytest.mark.trio +async def test_get_certhash_before_certificate_generation(): + cert_manager = CertificateManager() + + assert cert_manager.certificate is None + assert cert_manager.private_key is None + assert cert_manager.certhash is None + + certhash = cert_manager.get_certhash() + assert certhash == 'uEiNone' # Default value when no certificate is generated + + # generate the certificate and verify certhash is available + cert_manager.generate_self_signed_cert() + assert cert_manager.get_certhash() is not None \ No newline at end of file diff --git a/libp2p/transport/webrtc/test_listener.py b/libp2p/transport/webrtc/test_listener.py new file mode 100644 index 000000000..998a204b3 --- /dev/null +++ b/libp2p/transport/webrtc/test_listener.py @@ -0,0 +1,90 @@ +import pytest +import trio +from multiaddr import Multiaddr + +from libp2p.transport.webrtc.listener import WebRTCListener, SIGNAL_PROTOCOL +from libp2p.transport.webrtc.connection import WebRTCRawConnection + +@pytest.mark.trio +async def test_listen_and_accept_direct_connection(): + listener = WebRTCListener() + maddr = Multiaddr("/ip4/127.0.0.1/tcp/9999/webrtc-direct") + await listener.listen(maddr) + + class DummyChannel: + def __init__(self): + self.label = "test" + self._on_message = None + self._on_open = None + self.readyState = "open" + def on(self, event): + def decorator(fn): + if event == "message": + self._on_message = fn + elif event == "open": + self._on_open = fn + return fn + return decorator + def send(self, data): + pass + + dummy_channel = DummyChannel() + + # creating a WebRTCRawConnection and sending the listener's channel + conn = WebRTCRawConnection(listener.peer_id, dummy_channel) + await listener.conn_send_channel.send(conn) + await listener.accept() + + maddr = "/ip4/127.0.0.1/udp/9000/webrtc-direct/certhash/uEiqfMpAA6QOH0DT7YC5ggjBBG-c3CqLqbhQ4ovz4q6NyY/p2p/12D3KooWGiZiY9Vz2CCbbDkJATUrgZ6b6ov7f6AxZPkWR573V3Bx" + # result = await listener.listen(maddr) + # assert result is True + # addrs = listener.get_addrs() + # assert maddr in addrs + # assert isinstance(addrs, tuple) + # assert len(addrs) == 1 + + #Accept the connection + # accepted_conn = trio.move_on_after(1) + + # assert isinstance(accepted_conn, WebRTCRawConnection) + # assert accepted_conn.peer_id == listener.peer_id + # assert hasattr(accepted_conn, "channel") + # assert accepted_conn.channel.label == "test" + + +class DummyNetwork: + def __init__(self): + self.listen_called_with = [] + self._peer_id = b"dummy_peer_id" + async def listen(self, maddr): + self.listen_called_with.append(maddr) + return True + def get_peer_id(self): + return self._peer_id + +class DummyHost: + def __init__(self): + self.stream_handlers = {} + self._network = DummyNetwork() + def set_stream_handler(self, protocol_id, handler): + self.stream_handlers[protocol_id] = handler + def get_network(self): + return self._network + def get_id(self): + return self._network.get_peer_id() + +@pytest.mark.trio +async def test_listen_signaled_registers_stream_handler(): + listener = WebRTCListener() + dummy_host = DummyHost() + listener.set_host(dummy_host) + maddr = Multiaddr("/ip4/127.0.0.1/tcp/12345/ws/p2p-webrtc-star") + result = await listener.listen_signaled(maddr) + # Check that the stream handler was registered for the signal protocol + assert SIGNAL_PROTOCOL in dummy_host.stream_handlers + # Check that the network's listen was called with the correct multiaddr + assert maddr in dummy_host.get_network().listen_called_with + # Check that the listen_signaled returns True + assert result is True + # Check that the address is added to the listener's addrs + assert maddr in listener.get_addrs() \ No newline at end of file diff --git a/libp2p/transport/webrtc/test_loopback.py b/libp2p/transport/webrtc/test_loopback.py deleted file mode 100644 index b96c421c5..000000000 --- a/libp2p/transport/webrtc/test_loopback.py +++ /dev/null @@ -1,179 +0,0 @@ -import json -import logging - -from multiaddr import ( - Multiaddr, -) -from multiaddr.protocols import ( - Protocol, - add_protocol, -) -import trio -from trio import ( - Nursery, -) - -from libp2p import ( - new_host, -) -from libp2p.crypto.ed25519 import ( - create_new_key_pair, -) -from libp2p.host.basic_host import ( - IHost, -) -from libp2p.peer.id import ( - ID, -) -from libp2p.peer.peerinfo import ( - info_from_p2p_addr, -) -from libp2p.peer.peerstore import ( - PeerInfo, -) -from libp2p.pubsub.gossipsub import ( - GossipSub, -) -from libp2p.pubsub.pubsub import ( - Pubsub, -) - -from .connection import ( - WebRTCRawConnection, -) -from .gen_certhash import ( - filter_addresses, -) -from .webrtc import ( - SIGNAL_PROTOCOL, - WebRTCTransport, -) - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger("webrtc-loopback-test") - - -async def build_host(name: str) -> tuple[IHost, ID, WebRTCTransport]: - key_pair = create_new_key_pair() - peer_id = ID.from_pubkey(key_pair.public_key) - key_pair.private_key - key_pair.public_key - logger.info(f"[{name}] Peer ID: {peer_id}") - - host = new_host() - pubsub = Pubsub( - host, - GossipSub(protocols=[SIGNAL_PROTOCOL], degree=10, degree_low=3, degree_high=15), - None, - ) - webrtc_transport = WebRTCTransport(host, pubsub) - - return host, peer_id, webrtc_transport - - -async def run_loopback_test(nursery: Nursery) -> None: - host_b, peer_id_b, webrtc_transport_b = await build_host("Server") - logger.info(f"[B] Peer ID: {peer_id_b}") - addrs = host_b.get_addrs() - PeerInfo(peer_id_b, addrs) - listen = host_b.run(addrs) - logger.info(f"[B] Listening Addrs: {addrs} --- {listen}") - logger.info(f"[B] Active Addrs: {host_b.get_live_peers()}") - webrtc_proto = Protocol(name="webrtc", code=277, codec=None) - add_protocol(webrtc_proto) - # webrtc_conn = await webrtc_transport_b.create_listener(handler_func=) - - # await webrtc_conn.listen( - # Multiaddr(f"/ip4/127.0.0.1/tcp/9095/ws/p2p/{peer_id_b}/p2p-circuit/webrtc"), - # nursery, - # ) - - logger.info("[B] Listening WebRTC setup complete.") - - async def act_as_server() -> None: - try: - logger.info("[B] Waiting to accept connection...") - active_maddr = host_b.get_addrs() - info = info_from_p2p_addr(active_maddr[0]) - conn = host_b.connect(info) - logger.info(f"[B] Connection accepted. {conn}") - - stream = await host_b.new_stream(peer_id_b, [SIGNAL_PROTOCOL]) - offer_data = await stream.read() - offer_json = json.loads(offer_data.decode()) - - await webrtc_transport_b._handle_signal_message(peer_id_b, offer_json) - answer = await webrtc_transport_b.peer_connection.createAnswer() - await webrtc_transport_b.peer_connection.setLocalDescription(answer) - - await stream.write( - json.dumps({"sdp": answer.sdp, "sdpType": answer.type}).encode() - ) - - await trio.sleep(1) - - if webrtc_transport_b.data_channel is not None: - raw_conn = WebRTCRawConnection( - peer_id_b, webrtc_transport_b.data_channel - ) - msg = await raw_conn.read() - logger.info(f"[B] Received: {msg.decode()}") - await raw_conn.write(b"Reply from B") - await raw_conn.close() - else: - logger.error("[B] Data channel not established!") - except Exception as e: - logger.error(f"[B] Error in act_as_server: {e}") - - async def act_as_client() -> None: - host_a, peer_id_a, webrtc_client = await build_host("Client") - pc = webrtc_client._create_peer_connection() - await webrtc_client.create_data_channel(pc) - - valid = [ - Multiaddr(f"/ip4/127.0.0.1/udp/9095/webrtc/p2p/{peer_id_a}"), - Multiaddr(f"/ip4/127.0.0.1/tcp/4001/ws/p2p/{peer_id_a}/p2p-circuit/webrtc"), - Multiaddr(f"/ip4/127.0.0.1/tcp/4001/ws/p2p/{peer_id_b}/p2p-circuit/webrtc"), - Multiaddr(f"/ip4/127.0.0.1/udp/9095/webrtc/p2p/{peer_id_b}"), - ] - - maddr = filter_addresses(valid) - host_a.get_network().peerstore.add_addr(peer_id_a, maddr, 3000) - logger.info(f"[A] Peerstore updated with address: {maddr}") - - stream = await host_a.new_stream(peer_id_a, [SIGNAL_PROTOCOL]) - - offer = await webrtc_client.peer_connection.createOffer() - await webrtc_client.peer_connection.setLocalDescription(offer) - await stream.write( - json.dumps({"sdp": offer.sdp, "sdpType": offer.type}).encode() - ) - - answer_data = await stream.read() - answer_json = json.loads(answer_data.decode()) - await webrtc_client.handle_answer_from_peer(answer_json) - - await trio.sleep(1) - - if webrtc_client.data_channel is not None: - conn = WebRTCRawConnection(peer_id_a, webrtc_client.data_channel) - await conn.write(b"Hello from A") - reply = await conn.read() - logger.info(f"[A] Received: {reply.decode()}") - await conn.close() - else: - logger.error("[A] Data channel not established!") - - async with trio.open_nursery() as nursery: - nursery.start_soon(act_as_server) - await trio.sleep(1.5) - nursery.start_soon(act_as_client) - - -async def run_loopback_main() -> None: - async with trio.open_nursery() as nursery: - await run_loopback_test(nursery) - - -if __name__ == "__main__": - trio.run(lambda: run_loopback_main()) diff --git a/libp2p/transport/webrtc/test_signal.py b/libp2p/transport/webrtc/test_signal.py new file mode 100644 index 000000000..c61fc737b --- /dev/null +++ b/libp2p/transport/webrtc/test_signal.py @@ -0,0 +1,95 @@ +from unittest.mock import Mock +import json +import trio +from wsgiref.types import InputStream +import pytest +from libp2p.abc import ( + IHost, + INetStream, + INotifee, + TProtocol, +) +from libp2p.peer.id import ( + ID, +) +from .signal_service import ( + SignalService +) + +SIGNAL_PROTOCOL: TProtocol = TProtocol("/libp2p/webrtc/signal/1.0.0") + + +class TestSignalService: + + # Registering a handler for a specific message type and receiving that message type + @pytest.mark.trio + async def test_register_handler_and_receive_message(self): + mock_host = Mock(spec=IHost) + mock_stream = Mock(spec=INetStream) + mock_muxed_conn = Mock() + mock_peer_id = ID(b"test_peer_id") + + mock_muxed_conn.peer_id = mock_peer_id + mock_stream.muxed_conn = mock_muxed_conn + + message = {"type": "test_type", "data": "test_data"} + encoded_message = json.dumps(message).encode() + + # Configure read to return the message once, then empty data + mock_stream.read.side_effect = [encoded_message, b""] + signal_service = SignalService(mock_host) + + received_message = None + received_peer_id = None + handler_called_event = trio.Event() + + async def test_handler(msg, peer_id): + nonlocal received_message, received_peer_id + received_message = msg + received_peer_id = peer_id + handler_called_event.set() + + signal_service.set_handler("test_type", test_handler) + await signal_service.listen() + + mock_host.set_stream_handler.assert_called_once_with( + SIGNAL_PROTOCOL, signal_service.handle_signal + ) + + async with trio.open_nursery() as nursery: + nursery.start_soon(signal_service.handle_signal, mock_stream) + await handler_called_event.wait() + + assert received_message == message + assert received_peer_id == str(mock_peer_id) + assert mock_stream.read.call_count == 2 + + + # Handling empty data received from a stream + @pytest.mark.trio + async def test_handle_empty_data(self): + mock_host = Mock(spec=IHost) + mock_stream = Mock(spec=INetStream) + mock_muxed_conn = Mock() + mock_peer_id = ID(b"test_peer_id") + + mock_muxed_conn.peer_id = mock_peer_id + mock_stream.muxed_conn = mock_muxed_conn + + # Configure read to return empty data immediately + mock_stream.read.return_value = b"" + + signal_service = SignalService(mock_host) + + handler_called = False + + async def test_handler(msg, peer_id): + nonlocal handler_called + handler_called = True + + signal_service.set_handler("test_type", test_handler) + + await signal_service.handle_signal(mock_stream) + + assert not handler_called + mock_stream.read.assert_called_once_with(4096) \ No newline at end of file diff --git a/libp2p/transport/webrtc/test_webrtc.py b/libp2p/transport/webrtc/test_webrtc.py deleted file mode 100644 index 7cb79fa9b..000000000 --- a/libp2p/transport/webrtc/test_webrtc.py +++ /dev/null @@ -1,119 +0,0 @@ -import trio -import pytest -from libp2p.tools.constants import LISTEN_MADDR -from .webrtc import WebRTCTransport -from libp2p import new_host -from libp2p.pubsub.gossipsub import GossipSub -from libp2p.pubsub.pubsub import Pubsub -from multiaddr import Multiaddr -from libp2p.abc import TProtocol -from multiaddr.protocols import ( - Protocol, - add_protocol, -) -from .connection import WebRTCRawConnection -import logging -from libp2p.crypto.ed25519 import ( - create_new_key_pair, -) -from libp2p.host.basic_host import ( - IHost, -) -from libp2p.peer.id import ( - ID, -) -from .gen_certhash import CertificateManager, generate_webrtc_multiaddr - -logger = logging.getLogger("test-webrtc") -logging.basicConfig(level=logging.INFO) -SIGNAL_PROTOCOL: TProtocol = TProtocol("/libp2p/webrtc/signal/1.0.0") - -@pytest.mark.trio -async def test_webrtc_transport_end_to_end() -> None: - async with trio.open_nursery() as nursery: - Key_pair = create_new_key_pair() - peer_id = ID.from_pubkey(Key_pair.public_key) - print(f"Peer ID: {peer_id}") - # Create Peer A - host_a = new_host(key_pair=Key_pair) - pubsub_a = Pubsub( - host_a, - GossipSub(protocols=[SIGNAL_PROTOCOL], degree=10, degree_low=3, degree_high=15), - None, - ) - transport_a = WebRTCTransport(host_a, pubsub_a) - await transport_a.start() - - # Create Peer B - host_b = new_host() - pubsub_b = Pubsub( - host_b, - GossipSub(protocols=[SIGNAL_PROTOCOL], degree=10, degree_low=3, degree_high=15), - None, - ) - transport_b = WebRTCTransport(host_b, pubsub_b) - await transport_b.start() - - list_addr= host_b.get_addrs() - peer_id = host_b.get_id() - # if "webrtc" not in Protocol.name : - add_protocol(Protocol(name="webrtc", code=288, codec=None)) - add_protocol(Protocol(name="webrtc-direct", code= 289, codec= None)) - add_protocol(Protocol(name="certhash", code= 292, codec= None)) - add_protocol(Protocol(name="uEiBRYnd2NEs_2ycHGcld_M94-cKNOoLZamuSaGz_ArLaUA", code= 293, codec= None)) - maddr_b = Multiaddr(f"/ip4/0.0.0.0/tcp/0/ws/p2p/{host_b.get_id()}/p2p-circuit/webrtc") - - list_multiaddr= [ - Multiaddr(f"/ip4/127.0.0.1/udp/9095/webrtc/p2p/{peer_id}"), - Multiaddr(f"/ip4/127.0.0.1/tcp/4001/ws/p2p/{peer_id}/p2p-circuit/webrtc"), - Multiaddr(f"/ip4/127.0.0.1/tcp/4001/ws/p2p/{peer_id}/p2p-circuit/webrtc"), - Multiaddr(f"/ip4/127.0.0.1/udp/9095/webrtc/p2p/{peer_id}"), - ] - - # nursery.start_soon(host_a.run, []) - # nursery.start_soon(host_b.run, []) - host_a.run(listen_addrs=LISTEN_MADDR) - host_b.run(listen_addrs=LISTEN_MADDR) - # await trio.sleep(2) - - certhash = CertificateManager().get_certhash() - print(f"Certificate PEM: {certhash}") - signal_webrtc_maddr = generate_webrtc_multiaddr("192.168.0.1", str(peer_id), certhash="uEiBRYnd2NEs_2ycHGcld_M94-cKNOoLZamuSaGz_ArLaUA") - signal_maddr = "/ip4/0.0.0.0/tcp/ws/0" - print(f"Signal WebRTC Multiaddr: {signal_maddr}") - print(f"Signal WebRTC Multiaddr: {signal_webrtc_maddr}") - - # conn = await transport_a.dial(maddr_b) - conn: WebRTCRawConnection = await transport_a.dial(signal_maddr) - - # Check the channel is open - assert conn is not None - assert conn.channel.readyState == "open" - - logger.info("[Test] WebRTC channel open. Sending message...") - - # Send message - test_msg = "Hello from A" - conn.channel.send(test_msg) - logger.info(f"[Test] Peer A sent: {test_msg}") - - # Listen on B - received = trio.Event() - def on_message(msg: str) -> None: - if msg == test_msg: - logger.info(f"[Test] Peer B received: {msg}") - received.set() - - for ch in transport_b.connected_peers.values(): - if ch.readyState == "open": - ch.on("message", on_message) - - conn.channel.send(test_msg) - - trio.fail_after(5) - - logger.info("[Test] Message exchange succeeded") - await host_a.close() - await host_b.close() - nursery.cancel_scope.cancel() - diff --git a/libp2p/transport/webrtc/test_webrtc_direct_loopback.py b/libp2p/transport/webrtc/test_webrtc_direct_loopback.py new file mode 100644 index 000000000..411555b7c --- /dev/null +++ b/libp2p/transport/webrtc/test_webrtc_direct_loopback.py @@ -0,0 +1,119 @@ +import logging +from multiaddr.protocols import add_protocol, Protocol, P_WEBRTC_DIRECT +import trio +from libp2p import new_host +from libp2p.host.basic_host import BasicHost +from libp2p.crypto.ed25519 import create_new_key_pair +from libp2p.peer.id import ID +from libp2p.pubsub.gossipsub import GossipSub +from libp2p.pubsub.pubsub import Pubsub +from .gen_certhash import CertificateManager, create_webrtc_direct_multiaddr +from .webrtc import WebRTCTransport +from multiaddr import Multiaddr +import anyio +import anyio.to_thread +from contextlib import asynccontextmanager + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("webrtc-direct-loopback-test") + +async def build_host_and_transport(name: str): + key_pair = create_new_key_pair() + peer_id = ID.from_pubkey(key_pair.public_key) + logger.info(f"[{name}] Peer ID: {peer_id}") + + host = new_host() + pubsub = Pubsub( + host, + GossipSub(protocols=["/libp2p/webrtc/signal/1.0.0"], degree=10, degree_low=3, degree_high=15), + None, + ) + webrtc_transport = WebRTCTransport(host, pubsub) + return host, peer_id, webrtc_transport + +async def run_webrtc_direct_loopback_test(): + # Server (listener) + host_b, peer_id_b, webrtc_transport_b = await build_host_and_transport("Server") + cert_mgr_b = CertificateManager() + cert_mgr_b.generate_self_signed_cert() + + maddr_b = create_webrtc_direct_multiaddr( + ip="127.0.0.1", + port=9000, + peer_id=peer_id_b + ) + logger.info(f"[B] Listening on: {maddr_b}") + + listener = await webrtc_transport_b.create_listener() + listener.set_host(host_b) + await listener.listen(maddr_b) + listener_maddr = listener.get_addrs() + logger.info(f"[B] Listener Multiaddr: {listener_maddr}") + + # Client (dialer) + host_a, peer_id_a, webrtc_transport_a = await build_host_and_transport("Client") + cert_mgr_a = CertificateManager() + cert_mgr_a.generate_self_signed_cert() + + async def server_logic(): + try: + logger.info("[B] Waiting for incoming WebRTC-Direct connection...") + raw_conn = await listener.accept() + logger.info("[B] WebRTC-Direct connection accepted") + + # Testing bidirectional communication + msg = await raw_conn.read() + logger.info(f"[B] Received: {msg.decode()}") + await raw_conn.write(b"Reply from B (webrtc-direct)") + + # Testing stream handling + stream = await raw_conn.open_stream() + await stream.write(b"Stream test from B") + stream_data = await stream.read() + logger.info(f"[B] Stream data received: {stream_data.decode()}") + + await raw_conn.close() + logger.info("[B] Connection closed") + except Exception as e: + logger.error(f"[B] Error in server_logic: {e}") + raise + + async def client_logic(): + try: + await trio.sleep(1.0) + logger.info("[A] Dialing WebRTC-Direct...") + + with trio.move_on_after(60) as cancel_scope: # Add timeout + conn = await webrtc_transport_a.webrtc_direct_dial(maddr_b) + + if cancel_scope.cancelled_caught: + logger.error("[A] Connection attempt timed out") + return + + logger.info("[A] WebRTC-Direct connection established.") + await conn.write(b"Hello from A (webrtc-direct)") + reply = await conn.read() + logger.info(f"[A] Received: {reply.decode()}") + await conn.close() + + # Test stream handling + # stream = await conn.open_stream() + # await stream.write(b"Stream test from A") + # stream_data = await stream.read() + # logger.info(f"[A] Stream data received: {stream_data.decode()}") + + # await conn.close() + logger.info("[A] Connection closed") + except Exception as e: + logger.error(f"[A] Error in client_logic: {e}") + raise + + async with trio.open_nursery() as nursery: + nursery.start_soon(server_logic) + nursery.start_soon(client_logic) + +async def run_main(): + await run_webrtc_direct_loopback_test() + +if __name__ == "__main__": + trio.run(run_main) \ No newline at end of file diff --git a/libp2p/transport/webrtc/webrtc.py b/libp2p/transport/webrtc/webrtc.py index e02fad504..55b756189 100644 --- a/libp2p/transport/webrtc/webrtc.py +++ b/libp2p/transport/webrtc/webrtc.py @@ -6,27 +6,32 @@ ) from aiortc import ( + RTCConfiguration, RTCDataChannel, RTCIceCandidate, + RTCIceServer, RTCPeerConnection, RTCSessionDescription, ) -import anyio.from_thread from multiaddr import ( Multiaddr, ) +from multiaddr.protocols import ( + P_CERTHASH, + P_WEBRTC, + P_WEBRTC_DIRECT, +) import trio from libp2p.abc import ( ITransport, - THandler, TProtocol, ) from libp2p.crypto.ed25519 import ( create_new_key_pair, ) from libp2p.host.basic_host import ( - IHost, + BasicHost, ) from libp2p.peer.id import ( ID, @@ -43,6 +48,8 @@ ) from .gen_certhash import ( CertificateManager, + SDPMunger, + generate_local_certhash, parse_webrtc_maddr, ) from .listener import ( @@ -52,8 +59,6 @@ SignalService, ) -# from upgrader import TransportUpgrader - logger = logging.getLogger("webrtc") logging.basicConfig(level=logging.INFO) SIGNAL_PROTOCOL: TProtocol = TProtocol("/libp2p/webrtc/signal/1.0.0") @@ -61,35 +66,48 @@ class WebRTCTransport(ITransport): def __init__( - self, host: IHost, pubsub: Pubsub, config: Optional[dict[str, Any]] = None + self, host: BasicHost, pubsub: Pubsub, config: Optional[dict[str, Any]] = None ): self.host = host key_pair = create_new_key_pair() self.peer_id = ID.from_pubkey(key_pair.public_key) self.config = config or {} - self.certificate = CertificateManager().certificate + cert_mgr = CertificateManager() + cert_mgr.generate_self_signed_cert() + self.cert_mgr = cert_mgr + self.certificate = cert_mgr.get_certhash() self.data_channel: Optional[RTCDataChannel] = None self.connected_peers: dict[str, RTCDataChannel] = {} self.pubsub = pubsub self._listeners: list[Multiaddr] = [] self.peer_connection: RTCPeerConnection - config = {"iceServers": [...], "upgrader": Any} - self.ice_servers = self.config.get( - "iceServers", - [ - {"urls": "stun:stun.l.google.com:19302"}, - {"urls": "stun:stun1.l.google.com:19302"}, - ], - ) - + # config = {"iceServers": [...], "upgrader": Any} + # self.ice_servers = self.config.get( + # "iceServers", + # [ + # {"urls": "stun:stun.l.google.com:19302"}, + # {"urls": "stun:stun1.l.google.com:19302"}, + # ], + # ), + self.ice_servers = [ + RTCIceServer(urls="stun:stun.l.google.com:19302"), + RTCIceServer(urls="stun:stun1.l.google.com:19302"), + RTCIceServer(urls="stun:stun2.l.google.com:19302"), + RTCIceServer(urls="stun:stun3.l.google.com:19302"), + ] self.signal_service = SignalService(self.host) self.upgrader = self.config.get("upgrader") - - def _create_peer_connection(self) -> RTCPeerConnection: - return RTCPeerConnection( - configuration={"iceServers": self.ice_servers}, - certificates=[self.certificate], - ) + self.supported_protocols = { + "webrtc": P_WEBRTC, + "webrtc-direct": P_WEBRTC_DIRECT, + "certhash": P_CERTHASH, + } + + def _create_peer_connection(self, config) -> RTCPeerConnection: + if not config: + config = RTCPeerConnection(RTCConfiguration(iceServers=self.ice_servers)) + else: + return RTCPeerConnection(config) async def start(self) -> None: await self.start_peer_discovery() @@ -98,6 +116,11 @@ async def start(self) -> None: nursery.start_soon(self.handle_offer) logger.info("[WebRTC] WebRTCTransport started and listening for direct offers") + def can_handle(self, maddr: Multiaddr) -> bool: + """Check if transport can handle the multiaddr protocols""" + protocols = {p.name for p in maddr.protocols()} + return bool(protocols.intersection(self.supported_protocols.keys())) + async def start_peer_discovery(self) -> None: if not self.pubsub: gossipsub = GossipSub( @@ -122,12 +145,11 @@ async def handle_message() -> None: "webrtc-peer-discovery", str(self.peer_id).encode() ) - def verify_peer_certificate(self, remote_cert, expected_certhash: str): + def verify_peer_certificate(self, remote_cert, expected_certhash: str) -> bool: """ Compute the certhash of the remote certificate and compare to expected. """ - cert_mgr = CertificateManager() - actual_certhash = cert_mgr._compute_certhash(remote_cert) + actual_certhash = self.cert_mgr._compute_certhash(remote_cert) if actual_certhash != expected_certhash: raise ValueError( f"Certhash: expected {expected_certhash}, got {actual_certhash}" @@ -154,6 +176,34 @@ def on_message(message: Any) -> None: return channel + async def create_listener(self) -> WebRTCListener: + """ + Set up a WebRTC listener that waits for incoming data channels. + When a remote peer connects and opens a data channel, + wrap it in WebRTCRawConnection and pass to handler. + """ + pc = self._create_peer_connection(config=RTCConfiguration(iceServers=[])) + channel_ready = trio.Event() + + def on_datachannel(channel: RTCDataChannel): + logger.info("[WebRTC] Incoming data channel received") + WebRTCRawConnection(self.peer_id, channel) + + @channel.on("open") + async def on_open(): + logger.info("[WebRTC] Data channel opened by remote peer") + # handler_func(raw_conn) + channel_ready.set() + + pc.on("datachannel", on_datachannel) + + listener = WebRTCListener() + listener.set_host(self.host) + self.peer_connection = pc + + logger.info("[WebRTC] Listener created and waiting for incoming connections") + return listener + def relay_message(self, message: Any, exclude_peer: Optional[str] = None) -> None: """ Relay incoming message to all other connected peers, excluding the sender. @@ -183,9 +233,15 @@ async def _handle_signal_message(self, peer_id: str, data: dict[str, Any]): await self._handle_signal_ice(peer_id, data) async def _handle_signal_offer(self, peer_id: str, data: dict[str, Any]): - pc = self._create_peer_connection() + pc = self._create_peer_connection(config=None) self.peer_connection = pc + offer = RTCSessionDescription(sdp=data["sdp"], type=data["sdpType"]) + await pc.setRemoteDescription(offer) + remote_cert = await pc.__certificates[0] + expected_certhash = data.get("certhash") + self.verify_peer_certificate(remote_cert, expected_certhash) + channel_ready = trio.Event() @pc.on("datachannel") @@ -193,28 +249,21 @@ def on_datachannel(channel): self.connected_peers[peer_id] = channel @channel.on("open") - def on_open(): - channel_ready.set() + async def on_open(): + await channel_ready.set() @channel.on("message") - def on_message(msg): - self.relay_message(msg, exclude_peer=peer_id) - - offer = RTCSessionDescription(sdp=data["sdp"], type=data["sdpType"]) - await pc.setRemoteDescription(offer) + async def on_message(msg): + await self.relay_message(msg, exclude_peer=peer_id) answer = await pc.createAnswer() await pc.setLocalDescription(answer) await self.signal_service.send_answer( peer_id, - { - "sdp": pc.localDescription.sdp, - "sdpType": pc.localDescription.type, - "certhash": CertificateManager()._compute_certhash( - self.certificate.x509 - ), - }, + sdp=pc.localDescription.sdp, + sdp_type=pc.localDescription.type, + certhash=self.certificate, ) await channel_ready.wait() @@ -222,6 +271,11 @@ async def _handle_signal_answer(self, peer_id: str, data: dict[str, Any]): answer = RTCSessionDescription(sdp=data["sdp"], type=data["sdpType"]) await self.peer_connection.setRemoteDescription(answer) + cert_pem = self.cert_mgr.get_certificate_pem() + remote_cert = await generate_local_certhash(cert_pem=cert_pem) + expected_certhash = data.get("certhash") + self.verify_peer_certificate(remote_cert, expected_certhash) + async def _handle_signal_ice(self, peer_id: str, data: dict[str, Any]): candidate = RTCIceCandidate( component=data["component"], @@ -234,41 +288,83 @@ async def _handle_signal_ice(self, peer_id: str, data: dict[str, Any]): sdpMid=data["sdpMid"], ) await self.peer_connection.addIceCandidate(candidate) + await self.signal_service.send_ice_candidate( + peer_id=peer_id, candidate=candidate + ) async def handle_answer_from_peer(self, data: dict[str, Any]) -> None: answer = RTCSessionDescription(sdp=data["sdp"], type=data["sdpType"]) await self.peer_connection.setRemoteDescription(answer) - async def handle_ice_candidate(self, data: dict[str, Any]) -> None: - candidate = RTCIceCandidate( - component=data["component"], - foundation=data["foundation"], - priority=data["priority"], - ip=data["ip"], - protocol=data["protocol"], - port=data["port"], - type=data["candidateType"], - sdpMid=data["sdpMid"], - ) - await self.peer_connection.addIceCandidate(candidate) + async def handle_offer(self): + logger.info("[signal] Listening for incoming offers via SignalService") + await self.signal_service.listen() - async def create_listener(self, handler_func: THandler) -> WebRTCListener: - def on_new_stream(stream): - handler_func(stream) + async def _on_offer(msg): + try: + data = json.loads(msg) + remote_peer_id = data["peer_id"] + offer = RTCSessionDescription(sdp=data["sdp"], type=data["sdpType"]) - pc = self._create_peer_connection() - channel = await self.create_data_channel(pc, "webrtc-dial") - channel_ready = trio.Event() + pc = self._create_peer_connection(config=None) + logger.info( + f"[webrtc-direct] Received offer from peer {remote_peer_id}" + ) + channel_ready = trio.Event() - @channel.on("open") - def on_open(): - channel_ready.set() + @pc.on("datachannel") + def on_datachannel(channel): + logger.info( + f"[webrtc-direct] Datachannel received from {remote_peer_id}" + ) + self.connected_peers[remote_peer_id] = channel - raw_conn = WebRTCRawConnection(self.peer_id, channel) - raw_conn.on_stream(on_new_stream) - if not self.host: - raise RuntimeError("Host not initialized") - return WebRTCListener(handler=handler_func) + @channel.on("open") + async def on_open(): + logger.info( + f"[webrtc-direct] Channel open with {remote_peer_id}" + ) + await channel_ready.set() + + @channel.on("message") + async def on_message(msg): + logger.info(f"[Relay] Received from {remote_peer_id}: {msg}") + await self.relay_message(msg, exclude_peer=remote_peer_id) + + offer = RTCSessionDescription(sdp=data["sdp"], type=data["sdpType"]) + await pc.setRemoteDescription(offer) + remote_cert = self.cert_mgr.generate_self_signed_cert("remote-cert") + expected_certhash = data.get("certhash") + self.verify_peer_certificate(remote_cert, expected_certhash) + self.verify_peer_id(remote_peer_id, str(self.peer_id)) + + answer = await pc.createAnswer() + await pc.setLocalDescription(answer) + + response_topic = f"webrtc-answer-{remote_peer_id}" + await self.pubsub.publish( + response_topic, + json.dumps( + { + "peer_id": str(self.peer_id), + "sdp": pc.localDescription.sdp, + "sdpType": pc.localDescription.type, + "certhash": self.certificate, + } + ).encode(), + ) + logger.info(f"ans sent to peer {remote_peer_id} via {response_topic}") + await channel_ready.wait() + + except Exception as e: + logger.error(f"[webrtc-direct] Error handling offer: {e}") + + offer_topic = f"webrtc-offer-{remote_peer_id}" + logger.info(f"[webrtc-direct] Subscribing to topic: {offer_topic}") + topic = await self.pubsub.subscribe(offer_topic) + + async for msg in topic: + await _on_offer(msg) async def handle_incoming_candidates( self, stream: Any, peer_connection: RTCPeerConnection @@ -298,22 +394,19 @@ async def dial(self, maddr: Multiaddr) -> WebRTCRawConnection: _, peer_id, certhash = parse_webrtc_maddr(maddr) stream = await self.host.new_stream(peer_id, [SIGNAL_PROTOCOL]) - pc = self._create_peer_connection() + pc = self._create_peer_connection(config=None) channel = await self.create_data_channel(pc, "webrtc-dial") channel_ready = trio.Event() self.connected_peers[peer_id] = channel - # cert_pem= CertificateManager() - # cert: CertificateManager = cert_pem.generate_self_signed_cert() - # print(f"Certificate PEM: {cert}") @channel.on("open") - def on_open(): - channel_ready.set() + async def on_open(): + await channel_ready.set() @channel.on("message") - def on_message(msg): + async def on_message(msg): logger.info(f"[Relay] Received from {peer_id}: {msg}") - self.relay_message(msg, exclude_peer=peer_id) + await self.relay_message(msg, exclude_peer=peer_id) @pc.on("icecandidate") def on_ice_candidate(candidate: Optional[RTCIceCandidate]) -> None: @@ -340,13 +433,9 @@ def on_ice_candidate(candidate: Optional[RTCIceCandidate]) -> None: # await self.signal_service.send_offer(peer_id, offer) await self.signal_service.send_offer( peer_id, - { - "sdp": pc.localDescription.sdp, - "sdpType": pc.localDescription.type, - "certhash": CertificateManager()._compute_certhash( - self.certificate.x509 - ), - }, + sdp=pc.localDescription.sdp, + sdp_type=pc.localDescription.type, + certhash=self.certificate, ) except Exception as e: logger.error(f"[Signaling] Failed to send offer to {peer_id}: {e}") @@ -368,9 +457,7 @@ def on_ice_candidate(candidate: Optional[RTCIceCandidate]) -> None: "peer_id": self.peer_id, "sdp": offer.sdp, "sdpType": offer.type, - "certhash": CertificateManager()._compute_certhash( - self.certificate.x509 - ), + "certhash": self.certificate, } ).encode() ) @@ -397,37 +484,53 @@ def on_ice_candidate(candidate: Optional[RTCIceCandidate]) -> None: return logical_stream async def webrtc_direct_dial(self, maddr: Multiaddr) -> WebRTCRawConnection: - protocols = [p.name for p in maddr.protocols()] - if "webrtc-direct" in protocols: - logger.info("[Dial] Detected /webrtc-direct multiaddr....") - - ip, peer_id, certhash = parse_webrtc_maddr(maddr) - if not ip or not peer_id: - raise ValueError("Missing IP or Peer ID in webrtc-direct multiaddr") - logger.info( - f"Parsed IP={ip}, PeerID={peer_id}, Certhash={certhash or 'None'}" - ) + if isinstance(maddr, str): + maddr = Multiaddr(maddr) + + [p.name for p in maddr.protocols()] - pc = self._create_peer_connection() - channel = await self.create_data_channel( - pc, label="py-libp2p-webrtc-direct" + ip = maddr.value_for_protocol("ip4") + port = int(maddr.value_for_protocol("udp")) + peer_id = maddr.value_for_protocol("p2p") + + if not all([ip, port, peer_id]): + raise ValueError( + "Invalid WebRTC-direct multiaddr - missing required components" ) - channel_ready = trio.Event() - self.connected_peers[peer_id] = channel - @channel.on("open") - def on_open() -> None: - logger.info(f"[webrtc-direct] Channel open with {peer_id}") - channel_ready.set() + logger.info(f"Dialing WebRTC-direct peer at {ip}:{port} (ID: {peer_id})") - @channel.on("message") - def on_message(msg: Any) -> None: - logger.info(f"[Relay] Received from {peer_id}: {msg}") - self.relay_message(msg, exclude_peer=peer_id) + config = RTCConfiguration( + iceServers=[], # No STUN/TURN for direct + ) + + pc = self._create_peer_connection(config=config) + channel = await self.create_data_channel(pc, label="py-libp2p-webrtc-direct") + channel_ready = trio.Event() + self.connected_peers[peer_id] = channel + + @channel.on("open") + async def on_open() -> None: + logger.info(f"[webrtc-direct] Channel open with {peer_id}") + await channel_ready.set() + + @channel.on("message") + def on_message(msg: Any) -> None: + logger.info(f"[Relay] Received from {peer_id}: {msg}") + self.relay_message(msg, exclude_peer=peer_id) + + offer = await pc.createOffer() - offer = await anyio.from_thread.run_sync(pc.createOffer) - await anyio.from_thread.run_sync(pc.setLocalDescription, offer) + # Create and munge offer + munged_sdp = SDPMunger.munge_offer(offer.sdp, ip, port) + offer.sdp = munged_sdp + await pc.setLocalDescription(offer) + # await trio.to_thread.run_sync(lambda: pc.set_local_description(offer)) + logger.info(f"[webrtc-direct] Created offer for {peer_id} with munged SDP") + # offer = await anyio.from_thread.run_sync(pc.createOffer) + # await anyio.from_thread.run_sync(pc.setLocalDescription, offer) + try: if self.pubsub is None: await self.start_peer_discovery() await self.pubsub.publish( @@ -437,9 +540,7 @@ def on_message(msg: Any) -> None: "peer_id": self.peer_id, "sdp": offer.sdp, "sdpType": offer.type, - "certhash": CertificateManager()._compute_certhash( - self.certificate.x509 - ), + "certhash": self.certificate, } ).encode(), ) @@ -450,89 +551,30 @@ def on_message(msg: Any) -> None: async for msg in topic: answer_data = json.loads(msg.data.decode()) answer = RTCSessionDescription(**answer_data) - # await pc.setRemoteDescription(answer) - await anyio.from_thread.run_sync(pc.setRemoteDescription, answer) - break - - await channel_ready.wait() - raw_conn = WebRTCRawConnection(peer_id, channel) - if self.upgrader: - upgraded_conn = await self.upgrader.upgrade_connection(raw_conn) - return upgraded_conn - else: - return raw_conn - else: - logger.info("[Dial] Falling back to regular signal-based WebRTC") - return await self.dial(maddr) - - async def handle_offer(self): - logger.info("[signal] Listening for incoming offers via SignalService") - # await self.signal_service.listen() - - async def _on_offer(msg): - try: - data = json.loads(msg.data.decode()) - remote_peer_id = data["peer_id"] - offer = RTCSessionDescription(sdp=data["sdp"], type=data["sdpType"]) - - pc = self._create_peer_connection() - logger.info( - f"[webrtc-direct] Received offer from peer {remote_peer_id}" - ) - channel_ready = trio.Event() - - @pc.on("datachannel") - def on_datachannel(channel): - logger.info( - f"[webrtc-direct] Datachannel received from {remote_peer_id}" - ) - self.connected_peers[remote_peer_id] = channel - - @channel.on("open") - def on_open(): - logger.info( - f"[webrtc-direct] Channel open with {remote_peer_id}" - ) - channel_ready.set() - - @channel.on("message") - def on_message(msg): - logger.info(f"[Relay] Received from {remote_peer_id}: {msg}") - self.relay_message(msg, exclude_peer=remote_peer_id) - - offer = RTCSessionDescription(sdp=data["sdp"], type=data["sdpType"]) - await pc.setRemoteDescription(offer) - remote_cert = self.peer_connection.getRemoteCertificates()[0] - expected_certhash = data.get("certhash") - self.verify_peer_certificate(remote_cert, expected_certhash) - self.verify_peer_id(remote_peer_id, str(self.peer_id)) - answer = await pc.createAnswer() - await pc.setLocalDescription(answer) + def set_remote_description(answer): + return pc.setRemoteDescription(answer) - response_topic = f"webrtc-answer-{remote_peer_id}" - await self.pubsub.publish( - response_topic, - json.dumps( - { - "peer_id": str(self.peer_id), - "sdp": pc.localDescription.sdp, - "sdpType": pc.localDescription.type, - "certhash": CertificateManager()._compute_certhash( - self.certificate.x509 - ), - } - ).encode(), - ) - logger.info(f"ans sent to peer {remote_peer_id} via {response_topic}") - await channel_ready.wait() + # await trio.to_thread.run_sync(lambda: set_remote_description(answer)) + # break + await pc.setRemoteDescription(answer) + await trio.to_thread.run_sync(pc.setRemoteDescription, answer) + break + except Exception as e: + logger.error(f"[webrtc-direct] Failed to publish offer via pubsub: {e}") + raise - except Exception as e: - logger.error(f"[webrtc-direct] Error handling offer: {e}") + # Wait for connection + with trio.move_on_after(30) as cancel_scope: + await channel_ready.wait() - offer_topic = f"webrtc-offer-{remote_peer_id}" - logger.info(f"[webrtc-direct] Subscribing to topic: {offer_topic}") - topic = await self.pubsub.subscribe(offer_topic) + if cancel_scope.cancelled_caught: + await pc.close() + raise ConnectionError("WebRTC connection timed out") - async for msg in topic: - await _on_offer(msg) + raw_conn = WebRTCRawConnection(peer_id, channel) + if self.upgrader: + upgraded_conn = await self.upgrader.upgrade_connection(raw_conn) + return upgraded_conn + else: + return raw_conn From b957b46eacba6b90a776f9b217f18fa796f2d0cd Mon Sep 17 00:00:00 2001 From: Neha Kumari Date: Tue, 3 Jun 2025 02:41:50 +0530 Subject: [PATCH 7/9] fix(ci): resolve WebRTC-listener test suite& build issues --- docs/libp2p.transport.webrtc.rst | 2 +- .../transport/webrtc/test_gen_certificate.py | 42 +++++----- libp2p/transport/webrtc/test_listener.py | 33 ++++++-- libp2p/transport/webrtc/test_signal.py | 29 +++---- .../webrtc/test_webrtc_direct_loopback.py | 77 ++++++++++++------- 5 files changed, 110 insertions(+), 73 deletions(-) diff --git a/docs/libp2p.transport.webrtc.rst b/docs/libp2p.transport.webrtc.rst index 1acbe085e..8d95fbf0a 100644 --- a/docs/libp2p.transport.webrtc.rst +++ b/docs/libp2p.transport.webrtc.rst @@ -66,4 +66,4 @@ Module contents .. automodule:: libp2p.transport.webrtc :members: :show-inheritance: - :undoc-members: \ No newline at end of file + :undoc-members: diff --git a/libp2p/transport/webrtc/test_gen_certificate.py b/libp2p/transport/webrtc/test_gen_certificate.py index 0ca52adc3..91f5e4d0b 100644 --- a/libp2p/transport/webrtc/test_gen_certificate.py +++ b/libp2p/transport/webrtc/test_gen_certificate.py @@ -1,50 +1,48 @@ -from .gen_certhash import ( - CertificateManager -) import pytest -import trio from cryptography.x509.oid import ( NameOID, ) +from .gen_certhash import ( + CertificateManager, +) # Certificate generation with default common name @pytest.mark.trio async def test_generate_self_signed_cert_with_default_common_name(): cert_manager = CertificateManager() - + cert_manager.generate_self_signed_cert() - + assert cert_manager.certificate is not None assert cert_manager.private_key is not None assert cert_manager.certhash is not None - + # Verify the common name is the default "py-libp2p" cert_subject = cert_manager.certificate.subject common_name = cert_subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value assert common_name == "py-libp2p" - + # Verify certhash format (base64 URL-safe encoded string) certhash = cert_manager.get_certhash() assert isinstance(certhash, str) - assert "+" not in certhash + assert "+" not in certhash assert "/" not in certhash - # Accessing certhash before certificate generation @pytest.mark.trio async def test_get_certhash_before_certificate_generation(): - cert_manager = CertificateManager() - - assert cert_manager.certificate is None - assert cert_manager.private_key is None - assert cert_manager.certhash is None - - certhash = cert_manager.get_certhash() - assert certhash == 'uEiNone' # Default value when no certificate is generated - - # generate the certificate and verify certhash is available - cert_manager.generate_self_signed_cert() - assert cert_manager.get_certhash() is not None \ No newline at end of file + cert_manager = CertificateManager() + + assert cert_manager.certificate is None + assert cert_manager.private_key is None + assert cert_manager.certhash is None + + certhash = cert_manager.get_certhash() + assert certhash == "uEiNone" # Default value when no certificate is generated + + # generate the certificate and verify certhash is available + cert_manager.generate_self_signed_cert() + assert cert_manager.get_certhash() is not None diff --git a/libp2p/transport/webrtc/test_listener.py b/libp2p/transport/webrtc/test_listener.py index 998a204b3..ea206f2a7 100644 --- a/libp2p/transport/webrtc/test_listener.py +++ b/libp2p/transport/webrtc/test_listener.py @@ -1,9 +1,16 @@ import pytest -import trio -from multiaddr import Multiaddr +from multiaddr import ( + Multiaddr, +) + +from libp2p.transport.webrtc.connection import ( + WebRTCRawConnection, +) +from libp2p.transport.webrtc.listener import ( + SIGNAL_PROTOCOL, + WebRTCListener, +) -from libp2p.transport.webrtc.listener import WebRTCListener, SIGNAL_PROTOCOL -from libp2p.transport.webrtc.connection import WebRTCRawConnection @pytest.mark.trio async def test_listen_and_accept_direct_connection(): @@ -17,6 +24,7 @@ def __init__(self): self._on_message = None self._on_open = None self.readyState = "open" + def on(self, event): def decorator(fn): if event == "message": @@ -24,7 +32,9 @@ def decorator(fn): elif event == "open": self._on_open = fn return fn + return decorator + def send(self, data): pass @@ -35,7 +45,9 @@ def send(self, data): await listener.conn_send_channel.send(conn) await listener.accept() - maddr = "/ip4/127.0.0.1/udp/9000/webrtc-direct/certhash/uEiqfMpAA6QOH0DT7YC5ggjBBG-c3CqLqbhQ4ovz4q6NyY/p2p/12D3KooWGiZiY9Vz2CCbbDkJATUrgZ6b6ov7f6AxZPkWR573V3Bx" + # maddr = "/ip4/127.0.0.1/udp/9000/webrtc-direct/ + # certhash/uEiqfMpAA6QOH0DT7YC5ggjBBG-c3CqLqbhQ4ovz4q6NyY/ + # p2p/12D3KooWGiZiY9Vz2CCbbDkJATUrgZ6b6ov7f6AxZPkWR573V3Bx" # result = await listener.listen(maddr) # assert result is True # addrs = listener.get_addrs() @@ -43,7 +55,7 @@ def send(self, data): # assert isinstance(addrs, tuple) # assert len(addrs) == 1 - #Accept the connection + # Accept the connection # accepted_conn = trio.move_on_after(1) # assert isinstance(accepted_conn, WebRTCRawConnection) @@ -56,23 +68,30 @@ class DummyNetwork: def __init__(self): self.listen_called_with = [] self._peer_id = b"dummy_peer_id" + async def listen(self, maddr): self.listen_called_with.append(maddr) return True + def get_peer_id(self): return self._peer_id + class DummyHost: def __init__(self): self.stream_handlers = {} self._network = DummyNetwork() + def set_stream_handler(self, protocol_id, handler): self.stream_handlers[protocol_id] = handler + def get_network(self): return self._network + def get_id(self): return self._network.get_peer_id() + @pytest.mark.trio async def test_listen_signaled_registers_stream_handler(): listener = WebRTCListener() @@ -87,4 +106,4 @@ async def test_listen_signaled_registers_stream_handler(): # Check that the listen_signaled returns True assert result is True # Check that the address is added to the listener's addrs - assert maddr in listener.get_addrs() \ No newline at end of file + assert maddr in listener.get_addrs() diff --git a/libp2p/transport/webrtc/test_signal.py b/libp2p/transport/webrtc/test_signal.py index c61fc737b..4fcc91421 100644 --- a/libp2p/transport/webrtc/test_signal.py +++ b/libp2p/transport/webrtc/test_signal.py @@ -1,26 +1,28 @@ -from unittest.mock import Mock import json -import trio -from wsgiref.types import InputStream +from unittest.mock import ( + Mock, +) + import pytest +import trio + from libp2p.abc import ( IHost, INetStream, - INotifee, TProtocol, ) from libp2p.peer.id import ( ID, ) + from .signal_service import ( - SignalService + SignalService, ) SIGNAL_PROTOCOL: TProtocol = TProtocol("/libp2p/webrtc/signal/1.0.0") class TestSignalService: - # Registering a handler for a specific message type and receiving that message type @pytest.mark.trio async def test_register_handler_and_receive_message(self): @@ -64,7 +66,6 @@ async def test_handler(msg, peer_id): assert received_peer_id == str(mock_peer_id) assert mock_stream.read.call_count == 2 - # Handling empty data received from a stream @pytest.mark.trio async def test_handle_empty_data(self): @@ -72,24 +73,24 @@ async def test_handle_empty_data(self): mock_stream = Mock(spec=INetStream) mock_muxed_conn = Mock() mock_peer_id = ID(b"test_peer_id") - + mock_muxed_conn.peer_id = mock_peer_id mock_stream.muxed_conn = mock_muxed_conn - + # Configure read to return empty data immediately mock_stream.read.return_value = b"" signal_service = SignalService(mock_host) - + handler_called = False - + async def test_handler(msg, peer_id): nonlocal handler_called handler_called = True - + signal_service.set_handler("test_type", test_handler) - + await signal_service.handle_signal(mock_stream) assert not handler_called - mock_stream.read.assert_called_once_with(4096) \ No newline at end of file + mock_stream.read.assert_called_once_with(4096) diff --git a/libp2p/transport/webrtc/test_webrtc_direct_loopback.py b/libp2p/transport/webrtc/test_webrtc_direct_loopback.py index 411555b7c..5c5037530 100644 --- a/libp2p/transport/webrtc/test_webrtc_direct_loopback.py +++ b/libp2p/transport/webrtc/test_webrtc_direct_loopback.py @@ -1,22 +1,35 @@ import logging -from multiaddr.protocols import add_protocol, Protocol, P_WEBRTC_DIRECT + import trio -from libp2p import new_host -from libp2p.host.basic_host import BasicHost -from libp2p.crypto.ed25519 import create_new_key_pair -from libp2p.peer.id import ID -from libp2p.pubsub.gossipsub import GossipSub -from libp2p.pubsub.pubsub import Pubsub -from .gen_certhash import CertificateManager, create_webrtc_direct_multiaddr -from .webrtc import WebRTCTransport -from multiaddr import Multiaddr -import anyio -import anyio.to_thread -from contextlib import asynccontextmanager + +from libp2p import ( + new_host, +) +from libp2p.crypto.ed25519 import ( + create_new_key_pair, +) +from libp2p.peer.id import ( + ID, +) +from libp2p.pubsub.gossipsub import ( + GossipSub, +) +from libp2p.pubsub.pubsub import ( + Pubsub, +) + +from .gen_certhash import ( + CertificateManager, + create_webrtc_direct_multiaddr, +) +from .webrtc import ( + WebRTCTransport, +) logging.basicConfig(level=logging.INFO) logger = logging.getLogger("webrtc-direct-loopback-test") + async def build_host_and_transport(name: str): key_pair = create_new_key_pair() peer_id = ID.from_pubkey(key_pair.public_key) @@ -25,22 +38,26 @@ async def build_host_and_transport(name: str): host = new_host() pubsub = Pubsub( host, - GossipSub(protocols=["/libp2p/webrtc/signal/1.0.0"], degree=10, degree_low=3, degree_high=15), + GossipSub( + protocols=["/libp2p/webrtc/signal/1.0.0"], + degree=10, + degree_low=3, + degree_high=15, + ), None, ) webrtc_transport = WebRTCTransport(host, pubsub) return host, peer_id, webrtc_transport + async def run_webrtc_direct_loopback_test(): # Server (listener) host_b, peer_id_b, webrtc_transport_b = await build_host_and_transport("Server") cert_mgr_b = CertificateManager() cert_mgr_b.generate_self_signed_cert() - + maddr_b = create_webrtc_direct_multiaddr( - ip="127.0.0.1", - port=9000, - peer_id=peer_id_b + ip="127.0.0.1", port=9000, peer_id=peer_id_b ) logger.info(f"[B] Listening on: {maddr_b}") @@ -60,18 +77,18 @@ async def server_logic(): logger.info("[B] Waiting for incoming WebRTC-Direct connection...") raw_conn = await listener.accept() logger.info("[B] WebRTC-Direct connection accepted") - + # Testing bidirectional communication msg = await raw_conn.read() logger.info(f"[B] Received: {msg.decode()}") await raw_conn.write(b"Reply from B (webrtc-direct)") - + # Testing stream handling stream = await raw_conn.open_stream() await stream.write(b"Stream test from B") stream_data = await stream.read() logger.info(f"[B] Stream data received: {stream_data.decode()}") - + await raw_conn.close() logger.info("[B] Connection closed") except Exception as e: @@ -80,28 +97,28 @@ async def server_logic(): async def client_logic(): try: - await trio.sleep(1.0) + await trio.sleep(1.0) logger.info("[A] Dialing WebRTC-Direct...") - + with trio.move_on_after(60) as cancel_scope: # Add timeout - conn = await webrtc_transport_a.webrtc_direct_dial(maddr_b) - + conn = await webrtc_transport_a.webrtc_direct_dial(maddr_b) + if cancel_scope.cancelled_caught: logger.error("[A] Connection attempt timed out") return - + logger.info("[A] WebRTC-Direct connection established.") await conn.write(b"Hello from A (webrtc-direct)") reply = await conn.read() logger.info(f"[A] Received: {reply.decode()}") await conn.close() - + # Test stream handling # stream = await conn.open_stream() # await stream.write(b"Stream test from A") # stream_data = await stream.read() # logger.info(f"[A] Stream data received: {stream_data.decode()}") - + # await conn.close() logger.info("[A] Connection closed") except Exception as e: @@ -112,8 +129,10 @@ async def client_logic(): nursery.start_soon(server_logic) nursery.start_soon(client_logic) + async def run_main(): await run_webrtc_direct_loopback_test() + if __name__ == "__main__": - trio.run(run_main) \ No newline at end of file + trio.run(run_main) From 5ac3a6045199ecec755a0b0ede709f0c7bb42606 Mon Sep 17 00:00:00 2001 From: Neha Kumari Date: Tue, 3 Jun 2025 12:06:45 +0530 Subject: [PATCH 8/9] fix(docs):dependency config for builds --- .github/workflows/tox.yml | 1 + docs/libp2p.transport.webrtc.rst | 69 -------------------------------- setup.py | 1 + 3 files changed, 2 insertions(+), 69 deletions(-) delete mode 100644 docs/libp2p.transport.webrtc.rst diff --git a/.github/workflows/tox.yml b/.github/workflows/tox.yml index 8fe058f69..efe6c9a36 100644 --- a/.github/workflows/tox.yml +++ b/.github/workflows/tox.yml @@ -59,6 +59,7 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install tox + pip install -e ".[docs]" - name: Test with tox shell: bash run: | diff --git a/docs/libp2p.transport.webrtc.rst b/docs/libp2p.transport.webrtc.rst deleted file mode 100644 index 8d95fbf0a..000000000 --- a/docs/libp2p.transport.webrtc.rst +++ /dev/null @@ -1,69 +0,0 @@ -libp2p.transport.webrtc package -=============================== - -Submodules ----------- - -libp2p.transport.webrtc.connection module ------------------------------------------ - -.. automodule:: libp2p.transport.webrtc.connection - :members: - :show-inheritance: - :undoc-members: - -libp2p.transport.webrtc.gen\_certhash module --------------------------------------------- - -.. automodule:: libp2p.transport.webrtc.gen_certhash - :members: - :show-inheritance: - :undoc-members: - -libp2p.transport.webrtc.listener module ---------------------------------------- - -.. automodule:: libp2p.transport.webrtc.listener - :members: - :show-inheritance: - :undoc-members: - -libp2p.transport.webrtc.signal\_service module ----------------------------------------------- - -.. automodule:: libp2p.transport.webrtc.signal_service - :members: - :show-inheritance: - :undoc-members: - -libp2p.transport.webrtc.test\_loopback module ---------------------------------------------- - -.. automodule:: libp2p.transport.webrtc.test_loopback - :members: - :show-inheritance: - :undoc-members: - -libp2p.transport.webrtc.test\_webrtc module -------------------------------------------- - -.. automodule:: libp2p.transport.webrtc.test_webrtc - :members: - :show-inheritance: - :undoc-members: - -libp2p.transport.webrtc.webrtc module -------------------------------------- - -.. automodule:: libp2p.transport.webrtc.webrtc - :members: - :show-inheritance: - :undoc-members: - -Module contents ---------------- - -.. automodule:: libp2p.transport.webrtc - :members: - :show-inheritance: - :undoc-members: diff --git a/setup.py b/setup.py index a23d811a9..0ec2e064a 100644 --- a/setup.py +++ b/setup.py @@ -29,6 +29,7 @@ "sphinx>=6.0.0", "sphinx_rtd_theme>=1.0.0", "towncrier>=24,<25", + "aiortc>=1.5.0", ], "test": [ "p2pclient==0.2.0", From 444fd97178b87199c07c0a30c26691e8219cd6f3 Mon Sep 17 00:00:00 2001 From: Neha Kumari Date: Tue, 3 Jun 2025 12:10:16 +0530 Subject: [PATCH 9/9] fix(docs):rst config for builds --- docs/libp2p.transport.rst | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/libp2p.transport.rst b/docs/libp2p.transport.rst index 9e96c7172..0d92c48f5 100644 --- a/docs/libp2p.transport.rst +++ b/docs/libp2p.transport.rst @@ -8,7 +8,6 @@ Subpackages :maxdepth: 4 libp2p.transport.tcp - libp2p.transport.webrtc Submodules ----------