Skip to content

Commit dabb3a0

Browse files
FIXME: Make TProtocol Optional[TProtocol] to keep types consistent (#770)
* FIXME: Make TProtocol Optional[TProtocol] to keep types consistent * correct test case of test_protocol_muxer * add newsfragment * unit test added --------- Co-authored-by: Manu Sheel Gupta <[email protected]>
1 parent e20a9a3 commit dabb3a0

File tree

9 files changed

+167
-17
lines changed

9 files changed

+167
-17
lines changed

libp2p/host/basic_host.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,13 @@ async def _swarm_stream_handler(self, net_stream: INetStream) -> None:
295295
)
296296
await net_stream.reset()
297297
return
298+
if protocol is None:
299+
logger.debug(
300+
"no protocol negotiated, closing stream from peer %s",
301+
net_stream.muxed_conn.peer_id,
302+
)
303+
await net_stream.reset()
304+
return
298305
net_stream.set_protocol(protocol)
299306
if handler is None:
300307
logger.debug(

libp2p/protocol_muxer/multiselect.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,11 @@ def add_handler(
4848
"""
4949
self.handlers[protocol] = handler
5050

51-
# FIXME: Make TProtocol Optional[TProtocol] to keep types consistent
5251
async def negotiate(
5352
self,
5453
communicator: IMultiselectCommunicator,
5554
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
56-
) -> tuple[TProtocol, StreamHandlerFn | None]:
55+
) -> tuple[TProtocol | None, StreamHandlerFn | None]:
5756
"""
5857
Negotiate performs protocol selection.
5958
@@ -84,14 +83,14 @@ async def negotiate(
8483
raise MultiselectError() from error
8584

8685
else:
87-
protocol = TProtocol(command)
88-
if protocol in self.handlers:
86+
protocol_to_check = None if not command else TProtocol(command)
87+
if protocol_to_check in self.handlers:
8988
try:
90-
await communicator.write(protocol)
89+
await communicator.write(command)
9190
except MultiselectCommunicatorError as error:
9291
raise MultiselectError() from error
9392

94-
return protocol, self.handlers[protocol]
93+
return protocol_to_check, self.handlers[protocol_to_check]
9594
try:
9695
await communicator.write(PROTOCOL_NOT_FOUND_MSG)
9796
except MultiselectCommunicatorError as error:

libp2p/protocol_muxer/multiselect_client.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,10 @@ async def try_select(
134134
:raise MultiselectClientError: raised when protocol negotiation failed
135135
:return: selected protocol
136136
"""
137+
# Represent `None` protocol as an empty string.
138+
protocol_str = protocol if protocol is not None else ""
137139
try:
138-
await communicator.write(protocol)
140+
await communicator.write(protocol_str)
139141
except MultiselectCommunicatorError as error:
140142
raise MultiselectClientError() from error
141143

@@ -145,7 +147,7 @@ async def try_select(
145147
except MultiselectCommunicatorError as error:
146148
raise MultiselectClientError() from error
147149

148-
if response == protocol:
150+
if response == protocol_str:
149151
return protocol
150152
if response == PROTOCOL_NOT_FOUND_MSG:
151153
raise MultiselectClientError("protocol not supported")

libp2p/protocol_muxer/multiselect_communicator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,10 @@ async def write(self, msg_str: str) -> None:
3030
"""
3131
:raise MultiselectCommunicatorError: raised when failed to write to underlying reader
3232
""" # noqa: E501
33-
msg_bytes = encode_delim(msg_str.encode())
33+
if msg_str is None:
34+
msg_bytes = encode_delim(b"")
35+
else:
36+
msg_bytes = encode_delim(msg_str.encode())
3437
try:
3538
await self.read_writer.write(msg_bytes)
3639
except IOException as error:

libp2p/security/security_multistream.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
from libp2p.peer.id import (
1818
ID,
1919
)
20+
from libp2p.protocol_muxer.exceptions import (
21+
MultiselectError,
22+
)
2023
from libp2p.protocol_muxer.multiselect import (
2124
Multiselect,
2225
)
@@ -104,7 +107,7 @@ async def select_transport(
104107
:param is_initiator: true if we are the initiator, false otherwise
105108
:return: selected secure transport
106109
"""
107-
protocol: TProtocol
110+
protocol: TProtocol | None
108111
communicator = MultiselectCommunicator(conn)
109112
if is_initiator:
110113
# Select protocol if initiator
@@ -114,5 +117,7 @@ async def select_transport(
114117
else:
115118
# Select protocol if non-initiator
116119
protocol, _ = await self.multiselect.negotiate(communicator)
120+
if protocol is None:
121+
raise MultiselectError("fail to negotiate a security protocol")
117122
# Return transport from protocol
118123
return self.transports[protocol]

libp2p/stream_muxer/muxer_multistream.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
from libp2p.peer.id import (
1818
ID,
1919
)
20+
from libp2p.protocol_muxer.exceptions import (
21+
MultiselectError,
22+
)
2023
from libp2p.protocol_muxer.multiselect import (
2124
Multiselect,
2225
)
@@ -73,14 +76,16 @@ async def select_transport(self, conn: IRawConnection) -> TMuxerClass:
7376
:param conn: conn to choose a transport over
7477
:return: selected muxer transport
7578
"""
76-
protocol: TProtocol
79+
protocol: TProtocol | None
7780
communicator = MultiselectCommunicator(conn)
7881
if conn.is_initiator:
7982
protocol = await self.multiselect_client.select_one_of(
8083
tuple(self.transports.keys()), communicator
8184
)
8285
else:
8386
protocol, _ = await self.multiselect.negotiate(communicator)
87+
if protocol is None:
88+
raise MultiselectError("fail to negotiate a stream muxer protocol")
8489
return self.transports[protocol]
8590

8691
async def new_conn(self, conn: ISecureConn, peer_id: ID) -> IMuxedConn:

newsfragments/770.internal.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Make TProtocol as Optional[TProtocol] to keep types consistent in py-libp2p/libp2p/protocol_muxer/multiselect.py

tests/core/protocol_muxer/test_negotiate_timeout.py

Lines changed: 96 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1+
from collections import deque
2+
13
import pytest
24
import trio
35

4-
from libp2p.abc import (
5-
IMultiselectCommunicator,
6-
)
6+
from libp2p.abc import IMultiselectCommunicator, INetStream
77
from libp2p.custom_types import TProtocol
88
from libp2p.protocol_muxer.exceptions import (
99
MultiselectClientError,
@@ -13,6 +13,10 @@
1313
from libp2p.protocol_muxer.multiselect_client import MultiselectClient
1414

1515

16+
async def dummy_handler(stream: INetStream) -> None:
17+
pass
18+
19+
1620
class DummyMultiselectCommunicator(IMultiselectCommunicator):
1721
"""
1822
Dummy MultiSelectCommunicator to test out negotiate timmeout.
@@ -31,7 +35,7 @@ async def read(self) -> str:
3135

3236

3337
@pytest.mark.trio
34-
async def test_select_one_of_timeout():
38+
async def test_select_one_of_timeout() -> None:
3539
ECHO = TProtocol("/echo/1.0.0")
3640
communicator = DummyMultiselectCommunicator()
3741

@@ -42,7 +46,7 @@ async def test_select_one_of_timeout():
4246

4347

4448
@pytest.mark.trio
45-
async def test_query_multistream_command_timeout():
49+
async def test_query_multistream_command_timeout() -> None:
4650
communicator = DummyMultiselectCommunicator()
4751
client = MultiselectClient()
4852

@@ -51,9 +55,95 @@ async def test_query_multistream_command_timeout():
5155

5256

5357
@pytest.mark.trio
54-
async def test_negotiate_timeout():
58+
async def test_negotiate_timeout() -> None:
5559
communicator = DummyMultiselectCommunicator()
5660
server = Multiselect()
5761

5862
with pytest.raises(MultiselectError, match="handshake read timeout"):
5963
await server.negotiate(communicator, 2)
64+
65+
66+
class HandshakeThenHangCommunicator(IMultiselectCommunicator):
67+
handshaked: bool
68+
69+
def __init__(self) -> None:
70+
self.handshaked = False
71+
72+
async def write(self, msg_str: str) -> None:
73+
if msg_str == "/multistream/1.0.0":
74+
self.handshaked = True
75+
return
76+
77+
async def read(self) -> str:
78+
if not self.handshaked:
79+
return "/multistream/1.0.0"
80+
# After handshake, hang on read.
81+
await trio.sleep_forever()
82+
# Should not be reached.
83+
return ""
84+
85+
86+
@pytest.mark.trio
87+
async def test_negotiate_timeout_post_handshake() -> None:
88+
communicator = HandshakeThenHangCommunicator()
89+
server = Multiselect()
90+
with pytest.raises(MultiselectError, match="handshake read timeout"):
91+
await server.negotiate(communicator, 1)
92+
93+
94+
class MockCommunicator(IMultiselectCommunicator):
95+
def __init__(self, commands_to_read: list[str]):
96+
self.read_queue = deque(commands_to_read)
97+
self.written_data: list[str] = []
98+
99+
async def write(self, msg_str: str) -> None:
100+
self.written_data.append(msg_str)
101+
102+
async def read(self) -> str:
103+
if not self.read_queue:
104+
raise EOFError
105+
return self.read_queue.popleft()
106+
107+
108+
@pytest.mark.trio
109+
async def test_negotiate_empty_string_command() -> None:
110+
# server receives an empty string, which means client wants `None` protocol.
111+
server = Multiselect({None: dummy_handler})
112+
# Handshake, then empty command
113+
communicator = MockCommunicator(["/multistream/1.0.0", ""])
114+
protocol, handler = await server.negotiate(communicator)
115+
assert protocol is None
116+
assert handler == dummy_handler
117+
# Check that server sent back handshake and the protocol confirmation (empty string)
118+
assert communicator.written_data == ["/multistream/1.0.0", ""]
119+
120+
121+
@pytest.mark.trio
122+
async def test_negotiate_with_none_handler() -> None:
123+
# server has None handler, client sends "" to select it.
124+
server = Multiselect({None: dummy_handler, TProtocol("/proto1"): dummy_handler})
125+
# Handshake, then empty command
126+
communicator = MockCommunicator(["/multistream/1.0.0", ""])
127+
protocol, handler = await server.negotiate(communicator)
128+
assert protocol is None
129+
assert handler == dummy_handler
130+
# Check written data: handshake, protocol confirmation
131+
assert communicator.written_data == ["/multistream/1.0.0", ""]
132+
133+
134+
@pytest.mark.trio
135+
async def test_negotiate_with_none_handler_ls() -> None:
136+
# server has None handler, client sends "ls" then empty string.
137+
server = Multiselect({None: dummy_handler, TProtocol("/proto1"): dummy_handler})
138+
# Handshake, ls, empty command
139+
communicator = MockCommunicator(["/multistream/1.0.0", "ls", ""])
140+
protocol, handler = await server.negotiate(communicator)
141+
assert protocol is None
142+
assert handler == dummy_handler
143+
# Check written data: handshake, ls response, protocol confirmation
144+
assert communicator.written_data[0] == "/multistream/1.0.0"
145+
assert "/proto1" in communicator.written_data[1]
146+
# Note: `ls` should not list the `None` protocol.
147+
assert "None" not in communicator.written_data[1]
148+
assert "\n\n" not in communicator.written_data[1]
149+
assert communicator.written_data[2] == ""

tests/core/protocol_muxer/test_protocol_muxer.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,3 +159,41 @@ async def dummy_handler(stream):
159159
protocols = ms.get_protocols()
160160

161161
assert set(protocols) == {p1, p2, p3}
162+
163+
164+
@pytest.mark.trio
165+
async def test_negotiate_optional_tprotocol(security_protocol):
166+
with pytest.raises(Exception):
167+
await perform_simple_test(
168+
None,
169+
[None],
170+
[None],
171+
security_protocol,
172+
)
173+
174+
175+
@pytest.mark.trio
176+
async def test_negotiate_optional_tprotocol_client_none_server_no_none(
177+
security_protocol,
178+
):
179+
with pytest.raises(Exception):
180+
await perform_simple_test(None, [None], [PROTOCOL_ECHO], security_protocol)
181+
182+
183+
@pytest.mark.trio
184+
async def test_negotiate_optional_tprotocol_client_none_in_list(security_protocol):
185+
expected_selected_protocol = PROTOCOL_ECHO
186+
await perform_simple_test(
187+
expected_selected_protocol,
188+
[None, PROTOCOL_ECHO],
189+
[PROTOCOL_ECHO],
190+
security_protocol,
191+
)
192+
193+
194+
@pytest.mark.trio
195+
async def test_negotiate_optional_tprotocol_server_none_client_other(
196+
security_protocol,
197+
):
198+
with pytest.raises(Exception):
199+
await perform_simple_test(None, [PROTOCOL_ECHO], [None], security_protocol)

0 commit comments

Comments
 (0)