Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/tox.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install tox
pip install -e ".[docs]"
- name: Test with tox
shell: bash
run: |
Expand Down
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ repos:
language: system
always_run: true
pass_filenames: false
stages: [manual]
- repo: local
hooks:
- id: check-rst-files
Expand Down
Empty file.
80 changes: 80 additions & 0 deletions libp2p/transport/webrtc/connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import logging
from typing import (
Any,
)

from aiortc import (
RTCDataChannel,
)
import trio
from trio import (
MemoryReceiveChannel,
MemorySendChannel,
)

from libp2p.abc import (
IRawConnection,
)
from libp2p.peer.id import (
ID,
)
from libp2p.stream_muxer.mplex.mplex import (
Mplex,
)

logger = logging.getLogger("webrtc")
logging.basicConfig(level=logging.INFO)


class WebRTCRawConnection(IRawConnection):
def __init__(self, peer_id: ID, channel: RTCDataChannel):
self.peer_id = peer_id
self.channel = channel
self.send_channel: MemorySendChannel[Any]
self.receive_channel: MemoryReceiveChannel[Any]
self.send_channel, self.receive_channel = trio.open_memory_channel(50)

@channel.on("message")
def on_message(message: Any) -> None:
self.send_channel.send_nowait(message)

self.mplex = Mplex(self, self.peer_id)

def _send_func(self, data: bytes) -> None:
self.channel.send(data)

async def _recv_func(self) -> bytes:
return await self.receive_channel.receive()

async def open_stream(self) -> Any:
return await self.mplex.open_stream()

async def accept_stream(self) -> Any:
return await self.mplex.accept_stream()

async def read(self, n: int = -1) -> bytes:
return await self.receive_channel.receive()

async def write(self, data: bytes) -> None:
self.channel.send(data)

def get_remote_address(self) -> tuple[str, int] | None:
return self.get_remote_address()

async def close(self) -> None:
self.channel.close()
await self.send_channel.aclose()
await self.receive_channel.aclose()
await self.mplex.close()

def get_local_peer(self) -> ID:
return self.get_local_peer()

def get_local_private_key(self) -> Any:
return self.get_local_private_key()

def get_remote_peer(self) -> ID:
return self.get_remote_peer()

def get_remote_public_key(self) -> Any:
return self.get_remote_public_key()
218 changes: 218 additions & 0 deletions libp2p/transport/webrtc/gen_certhash.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
import base64
import datetime
import hashlib
from typing import (
Optional,
)

from aiortc import (
RTCCertificate,
)
import base58
from cryptography import (
x509,
)
from cryptography.hazmat.backends import (
default_backend,
)
from cryptography.hazmat.primitives import (
hashes,
serialization,
)
from cryptography.hazmat.primitives.asymmetric import (
rsa,
)
from cryptography.x509.oid import (
NameOID,
)
from multiaddr import (
Multiaddr,
)

from libp2p.peer.id import (
ID,
)

SIGNAL_PROTOCOL = "/libp2p/webrtc/signal/1.0.0"


class CertificateManager(RTCCertificate):
def __init__(self):
self.x509 = None
self.private_key = None
self.certificate = None
self.certhash = None

def generate_self_signed_cert(self, common_name: str = "py-libp2p") -> None:
self.private_key = rsa.generate_private_key(
public_exponent=65537, key_size=2048
)
subject = issuer = x509.Name(
[x509.NameAttribute(NameOID.COMMON_NAME, common_name)]
)
self.certificate = (
x509.CertificateBuilder()
.subject_name(subject)
.issuer_name(issuer)
.public_key(self.private_key.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(datetime.datetime.utcnow())
.not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=365))
.sign(self.private_key, hashes.SHA256())
)
self.certhash = self._compute_certhash(self.certificate)

def _compute_certhash(self, cert: x509.Certificate) -> str:
# Encode in DER format and compute SHA-256 hash
der_bytes = cert.public_bytes(serialization.Encoding.DER)
sha256_hash = hashlib.sha256(der_bytes).digest()
return base64.urlsafe_b64encode(sha256_hash).decode("utf-8").rstrip("=")

def get_certhash(self) -> str:
# return self.certhash
return f"uEi{self.certhash}"

def get_certificate_pem(self) -> bytes:
return self.certificate.public_bytes(serialization.Encoding.PEM)

def get_private_key_pem(self) -> bytes:
return self.private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
)


class SDPMunger:
"""Handle SDP modification for direct connections"""

@staticmethod
def munge_offer(sdp: str, ip: str, port: int) -> str:
"""Modify SDP offer for direct connection"""
lines = sdp.split("\n")
munged = []

for line in lines:
if line.startswith("a=candidate"):
# Modify ICE candidate to use provided IP/port
parts = line.split()
parts[4] = ip
parts[5] = str(port)
line = " ".join(parts)
munged.append(line)

return "\n".join(munged)

@staticmethod
def munge_answer(sdp: str, ip: str, port: int) -> str:
"""Modify SDP answer for direct connection"""
return SDPMunger.munge_offer(sdp, ip, port)


def create_webrtc_multiaddr(
ip: str, peer_id: ID, certhash: str, direct: bool = False
) -> Multiaddr:
"""Create WebRTC multiaddr with proper format"""
# For direct connections
if direct:
return Multiaddr(
f"/ip4/{ip}/udp/0/webrtc-direct" f"/certhash/{certhash}" f"/p2p/{peer_id}"
)

# For signaled connections
return Multiaddr(f"/ip4/{ip}/webrtc" f"/certhash/{certhash}" f"/p2p/{peer_id}")
# return Multiaddr(f"/ip4/{ip}/webrtc/p2p/{peer_id}")


def verify_certhash(remote_cert: x509.Certificate, expected_hash: str) -> bool:
"""Verify remote certificate hash matches expected"""
der_bytes = remote_cert.public_bytes(serialization.Encoding.DER)
conv_hash = base64.urlsafe_b64encode(hashlib.sha256(der_bytes).digest())
actual_hash = f"uEi{conv_hash.decode('utf-8').rstrip('=')}"
return actual_hash == expected_hash


def create_webrtc_direct_multiaddr(ip: str, port: int, peer_id: ID) -> Multiaddr:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are here 2 functions for direct_multiaddr.

"""Create a WebRTC-direct multiaddr"""
# Format: /ip4/<ip>/udp/<port>/webrtc-direct/p2p/<peer_id>
return Multiaddr(f"/ip4/{ip}/udp/{port}/webrtc-direct/p2p/{peer_id}")


def parse_webrtc_maddr(maddr: Multiaddr) -> tuple[str, ID, str]:
"""
Parse a WebRTC multiaddr like:
/ip4/147.28.186.157/udp/9095/webrtc-direct/certhash/uEiDFVmAomKdAbivdrcIKdXGyuij_ax8b8at0GY_MJXMlwg/p2p/12D3KooWFhXabKDwALpzqMbto94sB7rvmZ6M28hs9Y9xSopDKwQr/p2p-circuit
/ip6/2604:1380:4642:6600::3/tcp/9095/p2p/12D3KooWFhXabKDwALpzqMbto94sB7rvmZ6M28hs9Y9xSopDKwQr/p2p-circuit/webrtc
/ip4/147.28.186.157/udp/9095/webrtc-direct/certhash/uEiDFVmAomKdAbivdrcIKdXGyuij_ax8b8at0GY_MJXMlwg/p2p/12D3KooWFhXabKDwALpzqMbto94sB7rvmZ6M28hs9Y9xSopDKwQr/p2p-circuit/webrtc
/ip4/127.0.0.1/udp/9000/webrtc-direct/certhash/uEia...1jI/p2p/12D3KooW...6HEh
Returns (ip, peer_id, certhash)
"""
try:
if isinstance(maddr, str):
maddr = Multiaddr(maddr)

parts = maddr.to_string().split("/")

# Get IP (after ip4 or ip6)
ip_idx = parts.index("ip4" if "ip4" in parts else "ip6") + 1
ip = parts[ip_idx]

# Get certhash (after certhash)
certhash_idx = parts.index("certhash") + 1
certhash = parts[certhash_idx]

# Get peer ID (after p2p)
peer_id_idx = parts.index("p2p") + 1
peer_id = parts[peer_id_idx]

if not all([ip, peer_id, certhash]):
raise ValueError("Missing required components in multiaddr")

return ip, peer_id, certhash

except Exception as e:
raise ValueError(f"Invalid WebRTC ma: {e}")


def generate_local_certhash(cert_pem: bytes) -> bytes:
cert = x509.load_pem_x509_certificate(cert_pem.encode(), default_backend())
der_bytes = cert.public_bytes(encoding=serialization.Encoding.DER)
digest = hashlib.sha256(der_bytes).digest()
certhash = base58.b58encode(digest).decode()
print(f"local_certhash= {certhash}")
return f"uEi{certhash}" # js-libp2p compatible


def generate_webrtc_multiaddr(
ip: str, peer_id: str, certhash: Optional[str] = None
) -> Multiaddr:
if not certhash:
raise ValueError("certhash must be provided for /webrtc-direct multiaddr")

cert_mgr = CertificateManager()
certhash = cert_mgr.get_certhash() if not certhash else certhash
if not isinstance(peer_id, ID):
peer_id = ID(peer_id)

base = f"/ip4/{ip}/udp/9000/webrtc-direct/certhash/{certhash}/p2p/{peer_id}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we get port also as optional argument with default value 9000 instead of hardcoding


return Multiaddr(base)


def filter_addresses(addrs: list[Multiaddr]) -> list[Multiaddr]:
"""
Filters the given list of multiaddresses,
returning only those that are valid for WebRTC transport.

A valid WebRTC multiaddress typically contains /webrtc/ or /webrtc-direct/.
"""
valid_protocols = {"webrtc", "webrtc-direct"}

def is_valid_webrtc_addr(addr: Multiaddr) -> bool:
try:
protocols = [proto.name for proto in addr.protocols()]
return any(p in valid_protocols for p in protocols)
except Exception:
return False

return [addr for addr in addrs if is_valid_webrtc_addr(addr)]
Loading
Loading