Skip to content

Commit 5fcfc67

Browse files
fixme/correct-type (#746)
* fixme/correct-type * added newsfragment and test
1 parent dd14aad commit 5fcfc67

File tree

6 files changed

+48
-6
lines changed

6 files changed

+48
-6
lines changed

libp2p/abc.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,11 @@
5050
Pubsub,
5151
)
5252

53+
from typing import TYPE_CHECKING
54+
55+
if TYPE_CHECKING:
56+
from libp2p.protocol_muxer.multiselect import Multiselect
57+
5358
from libp2p.pubsub.pb import (
5459
rpc_pb2,
5560
)
@@ -1545,9 +1550,8 @@ def get_network(self) -> INetworkService:
15451550
15461551
"""
15471552

1548-
# FIXME: Replace with correct return type
15491553
@abstractmethod
1550-
def get_mux(self) -> Any:
1554+
def get_mux(self) -> "Multiselect":
15511555
"""
15521556
Retrieve the muxer instance for the host.
15531557
@@ -2158,6 +2162,7 @@ def add_handler(self, protocol: TProtocol, handler: StreamHandlerFn) -> None:
21582162
21592163
"""
21602164

2165+
@abstractmethod
21612166
def get_protocols(self) -> tuple[TProtocol | None, ...]:
21622167
"""
21632168
Retrieve the protocols for which handlers have been registered.
@@ -2168,7 +2173,6 @@ def get_protocols(self) -> tuple[TProtocol | None, ...]:
21682173
A tuple of registered protocol names.
21692174
21702175
"""
2171-
return tuple(self.handlers.keys())
21722176

21732177
@abstractmethod
21742178
async def negotiate(

libp2p/identity/identify/identify.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def _mk_identify_protobuf(
5959
) -> Identify:
6060
public_key = host.get_public_key()
6161
laddrs = host.get_addrs()
62-
protocols = host.get_mux().get_protocols()
62+
protocols = tuple(str(p) for p in host.get_mux().get_protocols() if p is not None)
6363

6464
observed_addr = observed_multiaddr.to_bytes() if observed_multiaddr else b""
6565
return Identify(

libp2p/protocol_muxer/multiselect.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,18 @@ async def negotiate(
101101
except trio.TooSlowError:
102102
raise MultiselectError("handshake read timeout")
103103

104+
def get_protocols(self) -> tuple[TProtocol | None, ...]:
105+
"""
106+
Retrieve the protocols for which handlers have been registered.
107+
108+
Returns
109+
-------
110+
tuple[TProtocol, ...]
111+
A tuple of registered protocol names.
112+
113+
"""
114+
return tuple(self.handlers.keys())
115+
104116
async def handshake(self, communicator: IMultiselectCommunicator) -> None:
105117
"""
106118
Perform handshake to agree on multiselect protocol.

libp2p/relay/circuit_v2/discovery.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,9 @@ async def _check_via_mux(self, peer_id: ID) -> bool | None:
292292
# Get protocols with proper typing
293293
mux_protocols = mux.get_protocols()
294294
if isinstance(mux_protocols, (list, tuple)):
295-
available_protocols = list(mux_protocols)
295+
available_protocols = [
296+
p for p in mux.get_protocols() if p is not None
297+
]
296298

297299
for protocol in available_protocols:
298300
try:
@@ -312,7 +314,7 @@ async def _check_via_mux(self, peer_id: ID) -> bool | None:
312314

313315
self._protocol_cache[peer_id] = peer_protocols
314316
protocol_str = str(PROTOCOL_ID)
315-
for protocol in peer_protocols:
317+
for protocol in map(TProtocol, peer_protocols):
316318
if protocol == protocol_str:
317319
return True
318320
return False

newsfragments/746.bugfix.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Improved type safety in `get_mux()` and `get_protocols()` by returning properly typed values instead
2+
of `Any`. Also updated `identify.py` and `discovery.py` to handle `None` values safely and
3+
compare protocols correctly.

tests/core/protocol_muxer/test_protocol_muxer.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from libp2p.custom_types import (
44
TProtocol,
55
)
6+
from libp2p.protocol_muxer.multiselect import Multiselect
67
from libp2p.tools.utils import (
78
create_echo_stream_handler,
89
)
@@ -138,3 +139,23 @@ async def test_multistream_command(security_protocol):
138139
# Dialer asks for unspoorted command
139140
with pytest.raises(ValueError, match="Command not supported"):
140141
await dialer.send_command(listener.get_id(), "random")
142+
143+
144+
@pytest.mark.trio
145+
async def test_get_protocols_returns_all_registered_protocols():
146+
ms = Multiselect()
147+
148+
async def dummy_handler(stream):
149+
pass
150+
151+
p1 = TProtocol("/echo/1.0.0")
152+
p2 = TProtocol("/foo/1.0.0")
153+
p3 = TProtocol("/bar/1.0.0")
154+
155+
ms.add_handler(p1, dummy_handler)
156+
ms.add_handler(p2, dummy_handler)
157+
ms.add_handler(p3, dummy_handler)
158+
159+
protocols = ms.get_protocols()
160+
161+
assert set(protocols) == {p1, p2, p3}

0 commit comments

Comments
 (0)