diff --git a/.github/workflows/tox.yml b/.github/workflows/tox.yml index 8fe058f69..efe6c9a36 100644 --- a/.github/workflows/tox.yml +++ b/.github/workflows/tox.yml @@ -59,6 +59,7 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install tox + pip install -e ".[docs]" - name: Test with tox shell: bash run: | diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1712b7f17..7d28fcc1d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -51,6 +51,7 @@ repos: language: system always_run: true pass_filenames: false + stages: [manual] - repo: local hooks: - id: check-rst-files diff --git a/libp2p/transport/webrtc/__init__.py b/libp2p/transport/webrtc/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/libp2p/transport/webrtc/connection.py b/libp2p/transport/webrtc/connection.py new file mode 100644 index 000000000..c84a6a967 --- /dev/null +++ b/libp2p/transport/webrtc/connection.py @@ -0,0 +1,80 @@ +import logging +from typing import ( + Any, +) + +from aiortc import ( + RTCDataChannel, +) +import trio +from trio import ( + MemoryReceiveChannel, + MemorySendChannel, +) + +from libp2p.abc import ( + IRawConnection, +) +from libp2p.peer.id import ( + ID, +) +from libp2p.stream_muxer.mplex.mplex import ( + Mplex, +) + +logger = logging.getLogger("webrtc") +logging.basicConfig(level=logging.INFO) + + +class WebRTCRawConnection(IRawConnection): + def __init__(self, peer_id: ID, channel: RTCDataChannel): + self.peer_id = peer_id + self.channel = channel + self.send_channel: MemorySendChannel[Any] + self.receive_channel: MemoryReceiveChannel[Any] + self.send_channel, self.receive_channel = trio.open_memory_channel(50) + + @channel.on("message") + def on_message(message: Any) -> None: + self.send_channel.send_nowait(message) + + self.mplex = Mplex(self, self.peer_id) + + def _send_func(self, data: bytes) -> None: + self.channel.send(data) + + async def _recv_func(self) -> bytes: + return await self.receive_channel.receive() + + async def open_stream(self) -> Any: + return await self.mplex.open_stream() + + async def accept_stream(self) -> Any: + return await self.mplex.accept_stream() + + async def read(self, n: int = -1) -> bytes: + return await self.receive_channel.receive() + + async def write(self, data: bytes) -> None: + self.channel.send(data) + + def get_remote_address(self) -> tuple[str, int] | None: + return self.get_remote_address() + + async def close(self) -> None: + self.channel.close() + await self.send_channel.aclose() + await self.receive_channel.aclose() + await self.mplex.close() + + def get_local_peer(self) -> ID: + return self.get_local_peer() + + def get_local_private_key(self) -> Any: + return self.get_local_private_key() + + def get_remote_peer(self) -> ID: + return self.get_remote_peer() + + def get_remote_public_key(self) -> Any: + return self.get_remote_public_key() diff --git a/libp2p/transport/webrtc/gen_certhash.py b/libp2p/transport/webrtc/gen_certhash.py new file mode 100644 index 000000000..c111f4cbf --- /dev/null +++ b/libp2p/transport/webrtc/gen_certhash.py @@ -0,0 +1,218 @@ +import base64 +import datetime +import hashlib +from typing import ( + Optional, +) + +from aiortc import ( + RTCCertificate, +) +import base58 +from cryptography import ( + x509, +) +from cryptography.hazmat.backends import ( + default_backend, +) +from cryptography.hazmat.primitives import ( + hashes, + serialization, +) +from cryptography.hazmat.primitives.asymmetric import ( + rsa, +) +from cryptography.x509.oid import ( + NameOID, +) +from multiaddr import ( + Multiaddr, +) + +from libp2p.peer.id import ( + ID, +) + +SIGNAL_PROTOCOL = "/libp2p/webrtc/signal/1.0.0" + + +class CertificateManager(RTCCertificate): + def __init__(self): + self.x509 = None + self.private_key = None + self.certificate = None + self.certhash = None + + def generate_self_signed_cert(self, common_name: str = "py-libp2p") -> None: + self.private_key = rsa.generate_private_key( + public_exponent=65537, key_size=2048 + ) + subject = issuer = x509.Name( + [x509.NameAttribute(NameOID.COMMON_NAME, common_name)] + ) + self.certificate = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(self.private_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.datetime.utcnow()) + .not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=365)) + .sign(self.private_key, hashes.SHA256()) + ) + self.certhash = self._compute_certhash(self.certificate) + + def _compute_certhash(self, cert: x509.Certificate) -> str: + # Encode in DER format and compute SHA-256 hash + der_bytes = cert.public_bytes(serialization.Encoding.DER) + sha256_hash = hashlib.sha256(der_bytes).digest() + return base64.urlsafe_b64encode(sha256_hash).decode("utf-8").rstrip("=") + + def get_certhash(self) -> str: + # return self.certhash + return f"uEi{self.certhash}" + + def get_certificate_pem(self) -> bytes: + return self.certificate.public_bytes(serialization.Encoding.PEM) + + def get_private_key_pem(self) -> bytes: + return self.private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + + +class SDPMunger: + """Handle SDP modification for direct connections""" + + @staticmethod + def munge_offer(sdp: str, ip: str, port: int) -> str: + """Modify SDP offer for direct connection""" + lines = sdp.split("\n") + munged = [] + + for line in lines: + if line.startswith("a=candidate"): + # Modify ICE candidate to use provided IP/port + parts = line.split() + parts[4] = ip + parts[5] = str(port) + line = " ".join(parts) + munged.append(line) + + return "\n".join(munged) + + @staticmethod + def munge_answer(sdp: str, ip: str, port: int) -> str: + """Modify SDP answer for direct connection""" + return SDPMunger.munge_offer(sdp, ip, port) + + +def create_webrtc_multiaddr( + ip: str, peer_id: ID, certhash: str, direct: bool = False +) -> Multiaddr: + """Create WebRTC multiaddr with proper format""" + # For direct connections + if direct: + return Multiaddr( + f"/ip4/{ip}/udp/0/webrtc-direct" f"/certhash/{certhash}" f"/p2p/{peer_id}" + ) + + # For signaled connections + return Multiaddr(f"/ip4/{ip}/webrtc" f"/certhash/{certhash}" f"/p2p/{peer_id}") + # return Multiaddr(f"/ip4/{ip}/webrtc/p2p/{peer_id}") + + +def verify_certhash(remote_cert: x509.Certificate, expected_hash: str) -> bool: + """Verify remote certificate hash matches expected""" + der_bytes = remote_cert.public_bytes(serialization.Encoding.DER) + conv_hash = base64.urlsafe_b64encode(hashlib.sha256(der_bytes).digest()) + actual_hash = f"uEi{conv_hash.decode('utf-8').rstrip('=')}" + return actual_hash == expected_hash + + +def create_webrtc_direct_multiaddr(ip: str, port: int, peer_id: ID) -> Multiaddr: + """Create a WebRTC-direct multiaddr""" + # Format: /ip4//udp//webrtc-direct/p2p/ + return Multiaddr(f"/ip4/{ip}/udp/{port}/webrtc-direct/p2p/{peer_id}") + + +def parse_webrtc_maddr(maddr: Multiaddr) -> tuple[str, ID, str]: + """ + Parse a WebRTC multiaddr like: + /ip4/147.28.186.157/udp/9095/webrtc-direct/certhash/uEiDFVmAomKdAbivdrcIKdXGyuij_ax8b8at0GY_MJXMlwg/p2p/12D3KooWFhXabKDwALpzqMbto94sB7rvmZ6M28hs9Y9xSopDKwQr/p2p-circuit + /ip6/2604:1380:4642:6600::3/tcp/9095/p2p/12D3KooWFhXabKDwALpzqMbto94sB7rvmZ6M28hs9Y9xSopDKwQr/p2p-circuit/webrtc + /ip4/147.28.186.157/udp/9095/webrtc-direct/certhash/uEiDFVmAomKdAbivdrcIKdXGyuij_ax8b8at0GY_MJXMlwg/p2p/12D3KooWFhXabKDwALpzqMbto94sB7rvmZ6M28hs9Y9xSopDKwQr/p2p-circuit/webrtc + /ip4/127.0.0.1/udp/9000/webrtc-direct/certhash/uEia...1jI/p2p/12D3KooW...6HEh + Returns (ip, peer_id, certhash) + """ + try: + if isinstance(maddr, str): + maddr = Multiaddr(maddr) + + parts = maddr.to_string().split("/") + + # Get IP (after ip4 or ip6) + ip_idx = parts.index("ip4" if "ip4" in parts else "ip6") + 1 + ip = parts[ip_idx] + + # Get certhash (after certhash) + certhash_idx = parts.index("certhash") + 1 + certhash = parts[certhash_idx] + + # Get peer ID (after p2p) + peer_id_idx = parts.index("p2p") + 1 + peer_id = parts[peer_id_idx] + + if not all([ip, peer_id, certhash]): + raise ValueError("Missing required components in multiaddr") + + return ip, peer_id, certhash + + except Exception as e: + raise ValueError(f"Invalid WebRTC ma: {e}") + + +def generate_local_certhash(cert_pem: bytes) -> bytes: + cert = x509.load_pem_x509_certificate(cert_pem.encode(), default_backend()) + der_bytes = cert.public_bytes(encoding=serialization.Encoding.DER) + digest = hashlib.sha256(der_bytes).digest() + certhash = base58.b58encode(digest).decode() + print(f"local_certhash= {certhash}") + return f"uEi{certhash}" # js-libp2p compatible + + +def generate_webrtc_multiaddr( + ip: str, peer_id: str, certhash: Optional[str] = None +) -> Multiaddr: + if not certhash: + raise ValueError("certhash must be provided for /webrtc-direct multiaddr") + + cert_mgr = CertificateManager() + certhash = cert_mgr.get_certhash() if not certhash else certhash + if not isinstance(peer_id, ID): + peer_id = ID(peer_id) + + base = f"/ip4/{ip}/udp/9000/webrtc-direct/certhash/{certhash}/p2p/{peer_id}" + + return Multiaddr(base) + + +def filter_addresses(addrs: list[Multiaddr]) -> list[Multiaddr]: + """ + Filters the given list of multiaddresses, + returning only those that are valid for WebRTC transport. + + A valid WebRTC multiaddress typically contains /webrtc/ or /webrtc-direct/. + """ + valid_protocols = {"webrtc", "webrtc-direct"} + + def is_valid_webrtc_addr(addr: Multiaddr) -> bool: + try: + protocols = [proto.name for proto in addr.protocols()] + return any(p in valid_protocols for p in protocols) + except Exception: + return False + + return [addr for addr in addrs if is_valid_webrtc_addr(addr)] diff --git a/libp2p/transport/webrtc/listener.py b/libp2p/transport/webrtc/listener.py new file mode 100644 index 000000000..67959bc6c --- /dev/null +++ b/libp2p/transport/webrtc/listener.py @@ -0,0 +1,194 @@ +import json +import logging +from typing import ( + Optional, +) + +from aiortc import ( + RTCConfiguration, + RTCDataChannel, + RTCIceCandidate, + RTCPeerConnection, + RTCSessionDescription, +) +from multiaddr import ( + Multiaddr, +) +import trio +from trio import ( + Event, + MemoryReceiveChannel, + MemorySendChannel, +) + +from libp2p.abc import ( + IListener, + TProtocol, +) + +# from .webrtc import ( +# WebRTCTransport, +# ) +from libp2p.crypto.ed25519 import ( + create_new_key_pair, +) +from libp2p.host.basic_host import ( + BasicHost, +) +from libp2p.peer.id import ( + ID, +) + +from .connection import ( + WebRTCRawConnection, +) + +logger = logging.getLogger("webrtc") +logging.basicConfig(level=logging.INFO) +SIGNAL_PROTOCOL: TProtocol = TProtocol("/libp2p/webrtc/signal/1.0.0") + + +class WebRTCListener(IListener): + def __init__(self): + self.host: BasicHost = None + key_pair = create_new_key_pair() + self.peer_id = ID.from_pubkey(key_pair.public_key) + # self.transport = WebRTCTransport() + self.conn_send_channel: MemorySendChannel[WebRTCRawConnection] + self.conn_receive_channel: MemoryReceiveChannel[WebRTCRawConnection] + self.conn_send_channel, self.conn_receive_channel = trio.open_memory_channel(10) + self.certificate = str + self._listen_addrs: list[ + Multiaddr + ] = [] # ['/ip4/127.0.0.1/tcp/4001', '/ip4/127.0.0.1/tcp/4034/ws/p2p-circuit'] + + def set_host(self, host: BasicHost) -> None: + self.host = host + + async def listen(self, maddr: Multiaddr) -> None: + """Listen for both direct and signaled connections""" + # print(f"Listening on {maddr} + {maddr.protocols}") + if "webrtc-direct" in maddr: + await self._listen_direct(maddr) + else: + await self.listen_signaled(maddr) + + async def _listen_direct(self, maddr: Multiaddr) -> None: + """Listen for direct WebRTC connections""" + pc = RTCPeerConnection(RTCConfiguration(iceServers=[])) + + @pc.on("datachannel") + def on_datachannel(channel): + conn = WebRTCRawConnection(self.peer_id, channel) + self.conn_send_channel.send_nowait(conn) + + @pc.on("connectionstatechange") + async def on_connectionstatechange(): + if pc.connectionState == "failed": + await pc.close() + + async def listen_signaled(self, maddr: Multiaddr) -> bool: + if not self.host: + raise RuntimeError("Host is not initialized in WebRTCListener") + + self.host.set_stream_handler( + SIGNAL_PROTOCOL, + self._handle_stream_wrapper, + ) + await self.host.get_network().listen(maddr) + if maddr not in self._listen_addrs: + self._listen_addrs.append(maddr) + return True + + def get_addrs(self) -> tuple[Multiaddr, ...]: + return tuple(self._listen_addrs) + + async def accept(self) -> WebRTCRawConnection: + return await self.conn_receive_channel.receive() + + async def _accept_loop(self): + """Accept incoming connections""" + while self._listening: + try: + # Wait for incoming connections from the transport + await trio.sleep(0.1) # Prevent busy waiting + + # Check connection pool for new connections + for peer_id, channel in self.transport.connection_pool.channels.items(): + if ( + channel.readyState == "open" + and peer_id not in self._processed_connections + ): + self._processed_connections.add(peer_id) + raw_conn = WebRTCRawConnection(self.transport.peer_id, channel) + await self.accept_queue.put(raw_conn) + + except Exception as e: + logger.error(f"[Listener] Error in accept loop: {e}") + await trio.sleep(1.0) + + async def close(self) -> None: + await self.conn_send_channel.aclose() + await self.conn_receive_channel.aclose() + logger.info("[Listener] Closed") + + async def _handle_stream_wrapper(self, stream: trio.SocketStream) -> None: + try: + await self._handle_stream_logic(stream) + except Exception as e: + logger.exception(f"Error in stream handler: {e}") + finally: + await stream.aclose() + + async def _handle_stream_logic(self, stream: trio.SocketStream) -> None: + pc = RTCPeerConnection() + channel_ready = Event() + + @pc.on("datachannel") + def on_datachannel(channel: RTCDataChannel) -> None: + logger.info(f"DataChannel received: {channel.label}") + + @channel.on("open") + def on_open() -> None: + logger.info("DataChannel opened.") + channel_ready.set() + + self.conn_send_channel.send_nowait( + WebRTCRawConnection(self.host.get_id(), channel) + ) + + @pc.on("icecandidate") + async def on_ice_candidate(candidate: Optional[RTCIceCandidate]) -> None: + if candidate: + msg = { + "type": "ice", + "candidateType": candidate.type, + "component": candidate.component, + "foundation": candidate.foundation, + "priority": candidate.priority, + "ip": candidate.ip, + "port": candidate.port, + "protocol": candidate.protocol, + "sdpMid": candidate.sdpMid, + } + try: + await stream.send_all(json.dumps(msg).encode()) + except Exception as e: + logger.warning(f"Failed to send ICE candidate: {e}") + + offer_data = await stream.receive_some(4096) + offer_msg = json.loads(offer_data.decode()) + offer = RTCSessionDescription(**offer_msg) + await pc.setRemoteDescription(offer) + + answer = await pc.createAnswer() + await pc.setLocalDescription(answer) + + await stream.send_all( + json.dumps( + {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type} + ).encode() + ) + + await channel_ready.wait() + await pc.close() diff --git a/libp2p/transport/webrtc/signal_service.py b/libp2p/transport/webrtc/signal_service.py new file mode 100644 index 000000000..9fa1d5264 --- /dev/null +++ b/libp2p/transport/webrtc/signal_service.py @@ -0,0 +1,108 @@ +from collections.abc import ( + Awaitable, +) +import json +from typing import ( + Callable, +) + +from aiortc import ( + RTCIceCandidate, +) + +from libp2p.abc import ( + IHost, + INetStream, + INotifee, + TProtocol, +) +from libp2p.peer.id import ( + ID, +) + +SIGNAL_PROTOCOL: TProtocol = TProtocol("/libp2p/webrtc/signal/1.0.0") + + +class SignalService(INotifee): + def __init__(self, host: IHost): + self.host = host + self.signal_protocol = SIGNAL_PROTOCOL + self._handlers: dict[str, Callable[[dict, str], Awaitable[None]]] = {} + + def set_handler( + self, msg_type: str, handler: Callable[[dict, str], Awaitable[None]] + ): + self._handlers[msg_type] = handler + + async def listen(self): + self.host.set_stream_handler(self.signal_protocol, self.handle_signal) + + async def handle_signal(self, stream: INetStream) -> None: + peer_id = stream.muxed_conn.peer_id + reader = stream + + while True: + try: + data = await reader.read(4096) + if not data: + break + msg = json.loads(data.decode()) + msg_type = msg.get("type") + if msg_type in self._handlers: + await self._handlers[msg_type](msg, str(peer_id)) + else: + print(f"No handler for msg type: {msg_type}") + except Exception as e: + print(f"Error in signal handler for {peer_id}: {e}") + break + + async def send_signal(self, peer_id: ID, message: dict): + try: + stream = await self.host.new_stream(peer_id, [self.signal_protocol]) + await stream.write(json.dumps(message).encode()) + await stream.close() + except Exception as e: + print(f"Failed to send signal to {peer_id}: {e}") + + async def send_offer(self, peer_id: ID, sdp: str, sdp_type: str, certhash: str): + await self.send_signal( + peer_id, + {"type": "offer", "sdp": sdp, "sdpType": sdp_type, "certhash": certhash}, + ) + + async def send_answer(self, peer_id: ID, sdp: str, sdp_type: str, certhash: str): + await self.send_signal( + peer_id, + {"type": "answer", "sdp": sdp, "sdpType": sdp_type, "certhash": certhash}, + ) + + async def send_ice_candidate(self, peer_id: ID, candidate: RTCIceCandidate): + await self.send_signal( + peer_id, + { + "type": "ice", + "candidateType": candidate.type, + "component": candidate.component, + "foundation": candidate.foundation, + "priority": candidate.priority, + "ip": candidate.ip, + "port": candidate.port, + "protocol": candidate.protocol, + "sdpMid": candidate.sdpMid, + }, + ) + + async def connected(self, network, conn): + pass + + async def disconnected(self, network, conn): + pass + + async def opened_stream(self, network, stream): + pass + + async def closed_stream(self, network, stream): + pass + + async def listen_close(self, network, multiaddr): + pass diff --git a/libp2p/transport/webrtc/test_gen_certificate.py b/libp2p/transport/webrtc/test_gen_certificate.py new file mode 100644 index 000000000..91f5e4d0b --- /dev/null +++ b/libp2p/transport/webrtc/test_gen_certificate.py @@ -0,0 +1,48 @@ +import pytest +from cryptography.x509.oid import ( + NameOID, +) + +from .gen_certhash import ( + CertificateManager, +) + + +# Certificate generation with default common name +@pytest.mark.trio +async def test_generate_self_signed_cert_with_default_common_name(): + cert_manager = CertificateManager() + + cert_manager.generate_self_signed_cert() + + assert cert_manager.certificate is not None + assert cert_manager.private_key is not None + assert cert_manager.certhash is not None + + # Verify the common name is the default "py-libp2p" + cert_subject = cert_manager.certificate.subject + common_name = cert_subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value + assert common_name == "py-libp2p" + + # Verify certhash format (base64 URL-safe encoded string) + certhash = cert_manager.get_certhash() + assert isinstance(certhash, str) + assert "+" not in certhash + assert "/" not in certhash + + +# Accessing certhash before certificate generation +@pytest.mark.trio +async def test_get_certhash_before_certificate_generation(): + cert_manager = CertificateManager() + + assert cert_manager.certificate is None + assert cert_manager.private_key is None + assert cert_manager.certhash is None + + certhash = cert_manager.get_certhash() + assert certhash == "uEiNone" # Default value when no certificate is generated + + # generate the certificate and verify certhash is available + cert_manager.generate_self_signed_cert() + assert cert_manager.get_certhash() is not None diff --git a/libp2p/transport/webrtc/test_listener.py b/libp2p/transport/webrtc/test_listener.py new file mode 100644 index 000000000..ea206f2a7 --- /dev/null +++ b/libp2p/transport/webrtc/test_listener.py @@ -0,0 +1,109 @@ +import pytest +from multiaddr import ( + Multiaddr, +) + +from libp2p.transport.webrtc.connection import ( + WebRTCRawConnection, +) +from libp2p.transport.webrtc.listener import ( + SIGNAL_PROTOCOL, + WebRTCListener, +) + + +@pytest.mark.trio +async def test_listen_and_accept_direct_connection(): + listener = WebRTCListener() + maddr = Multiaddr("/ip4/127.0.0.1/tcp/9999/webrtc-direct") + await listener.listen(maddr) + + class DummyChannel: + def __init__(self): + self.label = "test" + self._on_message = None + self._on_open = None + self.readyState = "open" + + def on(self, event): + def decorator(fn): + if event == "message": + self._on_message = fn + elif event == "open": + self._on_open = fn + return fn + + return decorator + + def send(self, data): + pass + + dummy_channel = DummyChannel() + + # creating a WebRTCRawConnection and sending the listener's channel + conn = WebRTCRawConnection(listener.peer_id, dummy_channel) + await listener.conn_send_channel.send(conn) + await listener.accept() + + # maddr = "/ip4/127.0.0.1/udp/9000/webrtc-direct/ + # certhash/uEiqfMpAA6QOH0DT7YC5ggjBBG-c3CqLqbhQ4ovz4q6NyY/ + # p2p/12D3KooWGiZiY9Vz2CCbbDkJATUrgZ6b6ov7f6AxZPkWR573V3Bx" + # result = await listener.listen(maddr) + # assert result is True + # addrs = listener.get_addrs() + # assert maddr in addrs + # assert isinstance(addrs, tuple) + # assert len(addrs) == 1 + + # Accept the connection + # accepted_conn = trio.move_on_after(1) + + # assert isinstance(accepted_conn, WebRTCRawConnection) + # assert accepted_conn.peer_id == listener.peer_id + # assert hasattr(accepted_conn, "channel") + # assert accepted_conn.channel.label == "test" + + +class DummyNetwork: + def __init__(self): + self.listen_called_with = [] + self._peer_id = b"dummy_peer_id" + + async def listen(self, maddr): + self.listen_called_with.append(maddr) + return True + + def get_peer_id(self): + return self._peer_id + + +class DummyHost: + def __init__(self): + self.stream_handlers = {} + self._network = DummyNetwork() + + def set_stream_handler(self, protocol_id, handler): + self.stream_handlers[protocol_id] = handler + + def get_network(self): + return self._network + + def get_id(self): + return self._network.get_peer_id() + + +@pytest.mark.trio +async def test_listen_signaled_registers_stream_handler(): + listener = WebRTCListener() + dummy_host = DummyHost() + listener.set_host(dummy_host) + maddr = Multiaddr("/ip4/127.0.0.1/tcp/12345/ws/p2p-webrtc-star") + result = await listener.listen_signaled(maddr) + # Check that the stream handler was registered for the signal protocol + assert SIGNAL_PROTOCOL in dummy_host.stream_handlers + # Check that the network's listen was called with the correct multiaddr + assert maddr in dummy_host.get_network().listen_called_with + # Check that the listen_signaled returns True + assert result is True + # Check that the address is added to the listener's addrs + assert maddr in listener.get_addrs() diff --git a/libp2p/transport/webrtc/test_signal.py b/libp2p/transport/webrtc/test_signal.py new file mode 100644 index 000000000..4fcc91421 --- /dev/null +++ b/libp2p/transport/webrtc/test_signal.py @@ -0,0 +1,96 @@ +import json +from unittest.mock import ( + Mock, +) + +import pytest +import trio + +from libp2p.abc import ( + IHost, + INetStream, + TProtocol, +) +from libp2p.peer.id import ( + ID, +) + +from .signal_service import ( + SignalService, +) + +SIGNAL_PROTOCOL: TProtocol = TProtocol("/libp2p/webrtc/signal/1.0.0") + + +class TestSignalService: + # Registering a handler for a specific message type and receiving that message type + @pytest.mark.trio + async def test_register_handler_and_receive_message(self): + mock_host = Mock(spec=IHost) + mock_stream = Mock(spec=INetStream) + mock_muxed_conn = Mock() + mock_peer_id = ID(b"test_peer_id") + + mock_muxed_conn.peer_id = mock_peer_id + mock_stream.muxed_conn = mock_muxed_conn + + message = {"type": "test_type", "data": "test_data"} + encoded_message = json.dumps(message).encode() + + # Configure read to return the message once, then empty data + mock_stream.read.side_effect = [encoded_message, b""] + signal_service = SignalService(mock_host) + + received_message = None + received_peer_id = None + handler_called_event = trio.Event() + + async def test_handler(msg, peer_id): + nonlocal received_message, received_peer_id + received_message = msg + received_peer_id = peer_id + handler_called_event.set() + + signal_service.set_handler("test_type", test_handler) + await signal_service.listen() + + mock_host.set_stream_handler.assert_called_once_with( + SIGNAL_PROTOCOL, signal_service.handle_signal + ) + + async with trio.open_nursery() as nursery: + nursery.start_soon(signal_service.handle_signal, mock_stream) + await handler_called_event.wait() + + assert received_message == message + assert received_peer_id == str(mock_peer_id) + assert mock_stream.read.call_count == 2 + + # Handling empty data received from a stream + @pytest.mark.trio + async def test_handle_empty_data(self): + mock_host = Mock(spec=IHost) + mock_stream = Mock(spec=INetStream) + mock_muxed_conn = Mock() + mock_peer_id = ID(b"test_peer_id") + + mock_muxed_conn.peer_id = mock_peer_id + mock_stream.muxed_conn = mock_muxed_conn + + # Configure read to return empty data immediately + mock_stream.read.return_value = b"" + + signal_service = SignalService(mock_host) + + handler_called = False + + async def test_handler(msg, peer_id): + nonlocal handler_called + handler_called = True + + signal_service.set_handler("test_type", test_handler) + + await signal_service.handle_signal(mock_stream) + + assert not handler_called + mock_stream.read.assert_called_once_with(4096) diff --git a/libp2p/transport/webrtc/test_webrtc_direct_loopback.py b/libp2p/transport/webrtc/test_webrtc_direct_loopback.py new file mode 100644 index 000000000..5c5037530 --- /dev/null +++ b/libp2p/transport/webrtc/test_webrtc_direct_loopback.py @@ -0,0 +1,138 @@ +import logging + +import trio + +from libp2p import ( + new_host, +) +from libp2p.crypto.ed25519 import ( + create_new_key_pair, +) +from libp2p.peer.id import ( + ID, +) +from libp2p.pubsub.gossipsub import ( + GossipSub, +) +from libp2p.pubsub.pubsub import ( + Pubsub, +) + +from .gen_certhash import ( + CertificateManager, + create_webrtc_direct_multiaddr, +) +from .webrtc import ( + WebRTCTransport, +) + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("webrtc-direct-loopback-test") + + +async def build_host_and_transport(name: str): + key_pair = create_new_key_pair() + peer_id = ID.from_pubkey(key_pair.public_key) + logger.info(f"[{name}] Peer ID: {peer_id}") + + host = new_host() + pubsub = Pubsub( + host, + GossipSub( + protocols=["/libp2p/webrtc/signal/1.0.0"], + degree=10, + degree_low=3, + degree_high=15, + ), + None, + ) + webrtc_transport = WebRTCTransport(host, pubsub) + return host, peer_id, webrtc_transport + + +async def run_webrtc_direct_loopback_test(): + # Server (listener) + host_b, peer_id_b, webrtc_transport_b = await build_host_and_transport("Server") + cert_mgr_b = CertificateManager() + cert_mgr_b.generate_self_signed_cert() + + maddr_b = create_webrtc_direct_multiaddr( + ip="127.0.0.1", port=9000, peer_id=peer_id_b + ) + logger.info(f"[B] Listening on: {maddr_b}") + + listener = await webrtc_transport_b.create_listener() + listener.set_host(host_b) + await listener.listen(maddr_b) + listener_maddr = listener.get_addrs() + logger.info(f"[B] Listener Multiaddr: {listener_maddr}") + + # Client (dialer) + host_a, peer_id_a, webrtc_transport_a = await build_host_and_transport("Client") + cert_mgr_a = CertificateManager() + cert_mgr_a.generate_self_signed_cert() + + async def server_logic(): + try: + logger.info("[B] Waiting for incoming WebRTC-Direct connection...") + raw_conn = await listener.accept() + logger.info("[B] WebRTC-Direct connection accepted") + + # Testing bidirectional communication + msg = await raw_conn.read() + logger.info(f"[B] Received: {msg.decode()}") + await raw_conn.write(b"Reply from B (webrtc-direct)") + + # Testing stream handling + stream = await raw_conn.open_stream() + await stream.write(b"Stream test from B") + stream_data = await stream.read() + logger.info(f"[B] Stream data received: {stream_data.decode()}") + + await raw_conn.close() + logger.info("[B] Connection closed") + except Exception as e: + logger.error(f"[B] Error in server_logic: {e}") + raise + + async def client_logic(): + try: + await trio.sleep(1.0) + logger.info("[A] Dialing WebRTC-Direct...") + + with trio.move_on_after(60) as cancel_scope: # Add timeout + conn = await webrtc_transport_a.webrtc_direct_dial(maddr_b) + + if cancel_scope.cancelled_caught: + logger.error("[A] Connection attempt timed out") + return + + logger.info("[A] WebRTC-Direct connection established.") + await conn.write(b"Hello from A (webrtc-direct)") + reply = await conn.read() + logger.info(f"[A] Received: {reply.decode()}") + await conn.close() + + # Test stream handling + # stream = await conn.open_stream() + # await stream.write(b"Stream test from A") + # stream_data = await stream.read() + # logger.info(f"[A] Stream data received: {stream_data.decode()}") + + # await conn.close() + logger.info("[A] Connection closed") + except Exception as e: + logger.error(f"[A] Error in client_logic: {e}") + raise + + async with trio.open_nursery() as nursery: + nursery.start_soon(server_logic) + nursery.start_soon(client_logic) + + +async def run_main(): + await run_webrtc_direct_loopback_test() + + +if __name__ == "__main__": + trio.run(run_main) diff --git a/libp2p/transport/webrtc/webrtc.py b/libp2p/transport/webrtc/webrtc.py new file mode 100644 index 000000000..55b756189 --- /dev/null +++ b/libp2p/transport/webrtc/webrtc.py @@ -0,0 +1,580 @@ +import json +import logging +from typing import ( + Any, + Optional, +) + +from aiortc import ( + RTCConfiguration, + RTCDataChannel, + RTCIceCandidate, + RTCIceServer, + RTCPeerConnection, + RTCSessionDescription, +) +from multiaddr import ( + Multiaddr, +) +from multiaddr.protocols import ( + P_CERTHASH, + P_WEBRTC, + P_WEBRTC_DIRECT, +) +import trio + +from libp2p.abc import ( + ITransport, + TProtocol, +) +from libp2p.crypto.ed25519 import ( + create_new_key_pair, +) +from libp2p.host.basic_host import ( + BasicHost, +) +from libp2p.peer.id import ( + ID, +) +from libp2p.pubsub.gossipsub import ( + GossipSub, +) +from libp2p.pubsub.pubsub import ( + Pubsub, +) + +from .connection import ( + WebRTCRawConnection, +) +from .gen_certhash import ( + CertificateManager, + SDPMunger, + generate_local_certhash, + parse_webrtc_maddr, +) +from .listener import ( + WebRTCListener, +) +from .signal_service import ( + SignalService, +) + +logger = logging.getLogger("webrtc") +logging.basicConfig(level=logging.INFO) +SIGNAL_PROTOCOL: TProtocol = TProtocol("/libp2p/webrtc/signal/1.0.0") + + +class WebRTCTransport(ITransport): + def __init__( + self, host: BasicHost, pubsub: Pubsub, config: Optional[dict[str, Any]] = None + ): + self.host = host + key_pair = create_new_key_pair() + self.peer_id = ID.from_pubkey(key_pair.public_key) + self.config = config or {} + cert_mgr = CertificateManager() + cert_mgr.generate_self_signed_cert() + self.cert_mgr = cert_mgr + self.certificate = cert_mgr.get_certhash() + self.data_channel: Optional[RTCDataChannel] = None + self.connected_peers: dict[str, RTCDataChannel] = {} + self.pubsub = pubsub + self._listeners: list[Multiaddr] = [] + self.peer_connection: RTCPeerConnection + # config = {"iceServers": [...], "upgrader": Any} + # self.ice_servers = self.config.get( + # "iceServers", + # [ + # {"urls": "stun:stun.l.google.com:19302"}, + # {"urls": "stun:stun1.l.google.com:19302"}, + # ], + # ), + self.ice_servers = [ + RTCIceServer(urls="stun:stun.l.google.com:19302"), + RTCIceServer(urls="stun:stun1.l.google.com:19302"), + RTCIceServer(urls="stun:stun2.l.google.com:19302"), + RTCIceServer(urls="stun:stun3.l.google.com:19302"), + ] + self.signal_service = SignalService(self.host) + self.upgrader = self.config.get("upgrader") + self.supported_protocols = { + "webrtc": P_WEBRTC, + "webrtc-direct": P_WEBRTC_DIRECT, + "certhash": P_CERTHASH, + } + + def _create_peer_connection(self, config) -> RTCPeerConnection: + if not config: + config = RTCPeerConnection(RTCConfiguration(iceServers=self.ice_servers)) + else: + return RTCPeerConnection(config) + + async def start(self) -> None: + await self.start_peer_discovery() + + async with trio.open_nursery() as nursery: + nursery.start_soon(self.handle_offer) + logger.info("[WebRTC] WebRTCTransport started and listening for direct offers") + + def can_handle(self, maddr: Multiaddr) -> bool: + """Check if transport can handle the multiaddr protocols""" + protocols = {p.name for p in maddr.protocols()} + return bool(protocols.intersection(self.supported_protocols.keys())) + + async def start_peer_discovery(self) -> None: + if not self.pubsub: + gossipsub = GossipSub( + protocols=[SIGNAL_PROTOCOL], degree=10, degree_low=3, degree_high=15 + ) + self.pubsub = Pubsub(self.host, gossipsub, None) + + topic = await self.pubsub.subscribe("webrtc-peer-discovery") + + async def handle_message() -> None: + async for msg in topic: + logger.info(f"Discovered Peer: {msg.data.decode()}") + + async with trio.open_nursery() as nursery: + nursery.start_soon(handle_message) + nursery.start_soon(self.handle_offer) + logger.info( + "[WebRTC] WebRTCTransport started and listening for direct offers" + ) + + await self.pubsub.publish( + "webrtc-peer-discovery", str(self.peer_id).encode() + ) + + def verify_peer_certificate(self, remote_cert, expected_certhash: str) -> bool: + """ + Compute the certhash of the remote certificate and compare to expected. + """ + actual_certhash = self.cert_mgr._compute_certhash(remote_cert) + if actual_certhash != expected_certhash: + raise ValueError( + f"Certhash: expected {expected_certhash}, got {actual_certhash}" + ) + + def verify_peer_id(self, remote_peer_id: str, expected_peer_id: str): + if remote_peer_id != expected_peer_id: + raise ValueError( + f"Peer ID mismatch: expected {expected_peer_id}, got {remote_peer_id}" + ) + + async def create_data_channel( + self, pc: RTCPeerConnection, label: str = "libp2p-webrtc" + ) -> RTCDataChannel: + channel = pc.createDataChannel(label) + + @channel.on("open") + def on_open() -> None: + logger.info("[WebRTC] Data channel open with peer") + + @channel.on("message") + def on_message(message: Any) -> None: + logger.info(f"[WebRTC] Message received: {message}") + + return channel + + async def create_listener(self) -> WebRTCListener: + """ + Set up a WebRTC listener that waits for incoming data channels. + When a remote peer connects and opens a data channel, + wrap it in WebRTCRawConnection and pass to handler. + """ + pc = self._create_peer_connection(config=RTCConfiguration(iceServers=[])) + channel_ready = trio.Event() + + def on_datachannel(channel: RTCDataChannel): + logger.info("[WebRTC] Incoming data channel received") + WebRTCRawConnection(self.peer_id, channel) + + @channel.on("open") + async def on_open(): + logger.info("[WebRTC] Data channel opened by remote peer") + # handler_func(raw_conn) + channel_ready.set() + + pc.on("datachannel", on_datachannel) + + listener = WebRTCListener() + listener.set_host(self.host) + self.peer_connection = pc + + logger.info("[WebRTC] Listener created and waiting for incoming connections") + return listener + + def relay_message(self, message: Any, exclude_peer: Optional[str] = None) -> None: + """ + Relay incoming message to all other connected peers, excluding the sender. + """ + for pid, channel in list(self.connected_peers.items()): + if pid == exclude_peer: + continue + if channel.readyState != "open": + logger.warning(f"[Relay] Channel to {pid} not open. Removing.") + self.connected_peers.pop(pid, None) + continue + try: + channel.send(message) + logger.info(f"[Relay] Forwarded message to {pid}") + except Exception as e: + logger.exception(f"[Relay] Error sending to {pid}: {e}") + self.connected_peers.pop(pid, None) + + async def _handle_signal_message(self, peer_id: str, data: dict[str, Any]): + self.host.get_id() + msg_type = data.get("type") + if msg_type == "offer": + await self._handle_signal_offer(peer_id, data) + elif msg_type == "answer": + await self._handle_signal_answer(peer_id, data) + elif msg_type == "ice": + await self._handle_signal_ice(peer_id, data) + + async def _handle_signal_offer(self, peer_id: str, data: dict[str, Any]): + pc = self._create_peer_connection(config=None) + self.peer_connection = pc + + offer = RTCSessionDescription(sdp=data["sdp"], type=data["sdpType"]) + await pc.setRemoteDescription(offer) + remote_cert = await pc.__certificates[0] + expected_certhash = data.get("certhash") + self.verify_peer_certificate(remote_cert, expected_certhash) + + channel_ready = trio.Event() + + @pc.on("datachannel") + def on_datachannel(channel): + self.connected_peers[peer_id] = channel + + @channel.on("open") + async def on_open(): + await channel_ready.set() + + @channel.on("message") + async def on_message(msg): + await self.relay_message(msg, exclude_peer=peer_id) + + answer = await pc.createAnswer() + await pc.setLocalDescription(answer) + + await self.signal_service.send_answer( + peer_id, + sdp=pc.localDescription.sdp, + sdp_type=pc.localDescription.type, + certhash=self.certificate, + ) + await channel_ready.wait() + + async def _handle_signal_answer(self, peer_id: str, data: dict[str, Any]): + answer = RTCSessionDescription(sdp=data["sdp"], type=data["sdpType"]) + await self.peer_connection.setRemoteDescription(answer) + + cert_pem = self.cert_mgr.get_certificate_pem() + remote_cert = await generate_local_certhash(cert_pem=cert_pem) + expected_certhash = data.get("certhash") + self.verify_peer_certificate(remote_cert, expected_certhash) + + async def _handle_signal_ice(self, peer_id: str, data: dict[str, Any]): + candidate = RTCIceCandidate( + component=data["component"], + foundation=data["foundation"], + priority=data["priority"], + ip=data["ip"], + protocol=data["protocol"], + port=data["port"], + type=data["candidateType"], + sdpMid=data["sdpMid"], + ) + await self.peer_connection.addIceCandidate(candidate) + await self.signal_service.send_ice_candidate( + peer_id=peer_id, candidate=candidate + ) + + async def handle_answer_from_peer(self, data: dict[str, Any]) -> None: + answer = RTCSessionDescription(sdp=data["sdp"], type=data["sdpType"]) + await self.peer_connection.setRemoteDescription(answer) + + async def handle_offer(self): + logger.info("[signal] Listening for incoming offers via SignalService") + await self.signal_service.listen() + + async def _on_offer(msg): + try: + data = json.loads(msg) + remote_peer_id = data["peer_id"] + offer = RTCSessionDescription(sdp=data["sdp"], type=data["sdpType"]) + + pc = self._create_peer_connection(config=None) + logger.info( + f"[webrtc-direct] Received offer from peer {remote_peer_id}" + ) + channel_ready = trio.Event() + + @pc.on("datachannel") + def on_datachannel(channel): + logger.info( + f"[webrtc-direct] Datachannel received from {remote_peer_id}" + ) + self.connected_peers[remote_peer_id] = channel + + @channel.on("open") + async def on_open(): + logger.info( + f"[webrtc-direct] Channel open with {remote_peer_id}" + ) + await channel_ready.set() + + @channel.on("message") + async def on_message(msg): + logger.info(f"[Relay] Received from {remote_peer_id}: {msg}") + await self.relay_message(msg, exclude_peer=remote_peer_id) + + offer = RTCSessionDescription(sdp=data["sdp"], type=data["sdpType"]) + await pc.setRemoteDescription(offer) + remote_cert = self.cert_mgr.generate_self_signed_cert("remote-cert") + expected_certhash = data.get("certhash") + self.verify_peer_certificate(remote_cert, expected_certhash) + self.verify_peer_id(remote_peer_id, str(self.peer_id)) + + answer = await pc.createAnswer() + await pc.setLocalDescription(answer) + + response_topic = f"webrtc-answer-{remote_peer_id}" + await self.pubsub.publish( + response_topic, + json.dumps( + { + "peer_id": str(self.peer_id), + "sdp": pc.localDescription.sdp, + "sdpType": pc.localDescription.type, + "certhash": self.certificate, + } + ).encode(), + ) + logger.info(f"ans sent to peer {remote_peer_id} via {response_topic}") + await channel_ready.wait() + + except Exception as e: + logger.error(f"[webrtc-direct] Error handling offer: {e}") + + offer_topic = f"webrtc-offer-{remote_peer_id}" + logger.info(f"[webrtc-direct] Subscribing to topic: {offer_topic}") + topic = await self.pubsub.subscribe(offer_topic) + + async for msg in topic: + await _on_offer(msg) + + async def handle_incoming_candidates( + self, stream: Any, peer_connection: RTCPeerConnection + ) -> None: + while True: + try: + raw = await stream.read() + data: dict[str, Any] = json.loads(raw.decode()) + if data.get("type") == "ice": + candidate = RTCIceCandidate( + component=data["component"], + foundation=data["foundation"], + priority=data["priority"], + ip=data["ip"], + protocol=data["protocol"], + port=data["port"], + type=data["candidateType"], + sdpMid=data["sdpMid"], + ) + await peer_connection.addIceCandidate(candidate) + except Exception as e: + logger.error(f"[ICE Trickling] Error reading ICE candidate: {e}") + await stream.close() + break + + async def dial(self, maddr: Multiaddr) -> WebRTCRawConnection: + _, peer_id, certhash = parse_webrtc_maddr(maddr) + stream = await self.host.new_stream(peer_id, [SIGNAL_PROTOCOL]) + + pc = self._create_peer_connection(config=None) + channel = await self.create_data_channel(pc, "webrtc-dial") + channel_ready = trio.Event() + self.connected_peers[peer_id] = channel + + @channel.on("open") + async def on_open(): + await channel_ready.set() + + @channel.on("message") + async def on_message(msg): + logger.info(f"[Relay] Received from {peer_id}: {msg}") + await self.relay_message(msg, exclude_peer=peer_id) + + @pc.on("icecandidate") + def on_ice_candidate(candidate: Optional[RTCIceCandidate]) -> None: + if candidate: + msg = { + "type": "ice", + "candidateType": candidate.type, + "component": candidate.component, + "foundation": candidate.foundation, + "priority": candidate.priority, + "port": candidate.port, + "protocol": candidate.protocol, + "ip": candidate.ip, + "sdpMid": candidate.sdpMid, + } + trio.lowlevel.spawn_system_task(stream.write, json.dumps(msg).encode()) + + trio.lowlevel.spawn_system_task(self.handle_incoming_candidates, stream, pc) + + offer = await pc.createOffer() + await pc.setLocalDescription(offer) + + try: + # await self.signal_service.send_offer(peer_id, offer) + await self.signal_service.send_offer( + peer_id, + sdp=pc.localDescription.sdp, + sdp_type=pc.localDescription.type, + certhash=self.certificate, + ) + except Exception as e: + logger.error(f"[Signaling] Failed to send offer to {peer_id}: {e}") + await stream.close() + raise + + await channel_ready.wait() + remote_cert = CertificateManager.getFingerprints() # remote cert for comparison + if not remote_cert: + raise ValueError("No remote certificate received") + remote_cert = remote_cert[0] + self.verify_peer_certificate(remote_cert, certhash) + self.verify_peer_id(peer_id, str(self.peer_id)) + + await stream.write( + json.dumps( + { + "type": "offer", + "peer_id": self.peer_id, + "sdp": offer.sdp, + "sdpType": offer.type, + "certhash": self.certificate, + } + ).encode() + ) + + try: + answer_data = await stream.read() + answer_msg: dict[str, Any] = json.loads(answer_data.decode()) + answer = RTCSessionDescription(**answer_msg) + await pc.setRemoteDescription(answer) + except Exception as e: + logger.error( + f"[Signaling] Failed to receive or process answer from {peer_id}: {e}" + ) + await stream.close() + raise + + await channel_ready.wait() + raw_conn = WebRTCRawConnection(self.peer_id, channel) + logical_stream = await raw_conn.open_stream() + if self.upgrader: + upgraded_conn = await self.upgrader.upgrade_connection(logical_stream) + return upgraded_conn + else: + return logical_stream + + async def webrtc_direct_dial(self, maddr: Multiaddr) -> WebRTCRawConnection: + if isinstance(maddr, str): + maddr = Multiaddr(maddr) + + [p.name for p in maddr.protocols()] + + ip = maddr.value_for_protocol("ip4") + port = int(maddr.value_for_protocol("udp")) + peer_id = maddr.value_for_protocol("p2p") + + if not all([ip, port, peer_id]): + raise ValueError( + "Invalid WebRTC-direct multiaddr - missing required components" + ) + + logger.info(f"Dialing WebRTC-direct peer at {ip}:{port} (ID: {peer_id})") + + config = RTCConfiguration( + iceServers=[], # No STUN/TURN for direct + ) + + pc = self._create_peer_connection(config=config) + channel = await self.create_data_channel(pc, label="py-libp2p-webrtc-direct") + channel_ready = trio.Event() + self.connected_peers[peer_id] = channel + + @channel.on("open") + async def on_open() -> None: + logger.info(f"[webrtc-direct] Channel open with {peer_id}") + await channel_ready.set() + + @channel.on("message") + def on_message(msg: Any) -> None: + logger.info(f"[Relay] Received from {peer_id}: {msg}") + self.relay_message(msg, exclude_peer=peer_id) + + offer = await pc.createOffer() + + # Create and munge offer + munged_sdp = SDPMunger.munge_offer(offer.sdp, ip, port) + offer.sdp = munged_sdp + await pc.setLocalDescription(offer) + # await trio.to_thread.run_sync(lambda: pc.set_local_description(offer)) + logger.info(f"[webrtc-direct] Created offer for {peer_id} with munged SDP") + + # offer = await anyio.from_thread.run_sync(pc.createOffer) + # await anyio.from_thread.run_sync(pc.setLocalDescription, offer) + try: + if self.pubsub is None: + await self.start_peer_discovery() + await self.pubsub.publish( + f"webrtc-offer-{peer_id}", + json.dumps( + { + "peer_id": self.peer_id, + "sdp": offer.sdp, + "sdpType": offer.type, + "certhash": self.certificate, + } + ).encode(), + ) + + logger.info(f"[webrtc-direct] Sent offer to peer {self.peer_id} via pubsub") + + topic = await self.pubsub.subscribe(f"webrtc-answer-{self.peer_id}") + async for msg in topic: + answer_data = json.loads(msg.data.decode()) + answer = RTCSessionDescription(**answer_data) + + def set_remote_description(answer): + return pc.setRemoteDescription(answer) + + # await trio.to_thread.run_sync(lambda: set_remote_description(answer)) + # break + await pc.setRemoteDescription(answer) + await trio.to_thread.run_sync(pc.setRemoteDescription, answer) + break + except Exception as e: + logger.error(f"[webrtc-direct] Failed to publish offer via pubsub: {e}") + raise + + # Wait for connection + with trio.move_on_after(30) as cancel_scope: + await channel_ready.wait() + + if cancel_scope.cancelled_caught: + await pc.close() + raise ConnectionError("WebRTC connection timed out") + + raw_conn = WebRTCRawConnection(peer_id, channel) + if self.upgrader: + upgraded_conn = await self.upgrader.upgrade_connection(raw_conn) + return upgraded_conn + else: + return raw_conn diff --git a/setup.py b/setup.py index a23d811a9..0ec2e064a 100644 --- a/setup.py +++ b/setup.py @@ -29,6 +29,7 @@ "sphinx>=6.0.0", "sphinx_rtd_theme>=1.0.0", "towncrier>=24,<25", + "aiortc>=1.5.0", ], "test": [ "p2pclient==0.2.0",