Skip to content

FIXME: Make TProtocol Optional[TProtocol] to keep types consistent #770

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions libp2p/host/basic_host.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,13 @@ async def _swarm_stream_handler(self, net_stream: INetStream) -> None:
)
await net_stream.reset()
return
if protocol is None:
logger.debug(
"no protocol negotiated, closing stream from peer %s",
net_stream.muxed_conn.peer_id,
)
await net_stream.reset()
return
net_stream.set_protocol(protocol)
if handler is None:
logger.debug(
Expand Down
11 changes: 5 additions & 6 deletions libp2p/protocol_muxer/multiselect.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,11 @@ def add_handler(
"""
self.handlers[protocol] = handler

# FIXME: Make TProtocol Optional[TProtocol] to keep types consistent
async def negotiate(
self,
communicator: IMultiselectCommunicator,
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
) -> tuple[TProtocol, StreamHandlerFn | None]:
) -> tuple[TProtocol | None, StreamHandlerFn | None]:
"""
Negotiate performs protocol selection.

Expand Down Expand Up @@ -84,14 +83,14 @@ async def negotiate(
raise MultiselectError() from error

else:
protocol = TProtocol(command)
if protocol in self.handlers:
protocol_to_check = None if not command else TProtocol(command)
if protocol_to_check in self.handlers:
try:
await communicator.write(protocol)
await communicator.write(command)
except MultiselectCommunicatorError as error:
raise MultiselectError() from error

return protocol, self.handlers[protocol]
return protocol_to_check, self.handlers[protocol_to_check]
try:
await communicator.write(PROTOCOL_NOT_FOUND_MSG)
except MultiselectCommunicatorError as error:
Expand Down
6 changes: 4 additions & 2 deletions libp2p/protocol_muxer/multiselect_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,10 @@ async def try_select(
:raise MultiselectClientError: raised when protocol negotiation failed
:return: selected protocol
"""
# Represent `None` protocol as an empty string.
protocol_str = protocol if protocol is not None else ""
try:
await communicator.write(protocol)
await communicator.write(protocol_str)
except MultiselectCommunicatorError as error:
raise MultiselectClientError() from error

Expand All @@ -145,7 +147,7 @@ async def try_select(
except MultiselectCommunicatorError as error:
raise MultiselectClientError() from error

if response == protocol:
if response == protocol_str:
return protocol
if response == PROTOCOL_NOT_FOUND_MSG:
raise MultiselectClientError("protocol not supported")
Expand Down
5 changes: 4 additions & 1 deletion libp2p/protocol_muxer/multiselect_communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ async def write(self, msg_str: str) -> None:
"""
:raise MultiselectCommunicatorError: raised when failed to write to underlying reader
""" # noqa: E501
msg_bytes = encode_delim(msg_str.encode())
if msg_str is None:
msg_bytes = encode_delim(b"")
else:
msg_bytes = encode_delim(msg_str.encode())
try:
await self.read_writer.write(msg_bytes)
except IOException as error:
Expand Down
7 changes: 6 additions & 1 deletion libp2p/security/security_multistream.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from libp2p.peer.id import (
ID,
)
from libp2p.protocol_muxer.exceptions import (
MultiselectError,
)
from libp2p.protocol_muxer.multiselect import (
Multiselect,
)
Expand Down Expand Up @@ -104,7 +107,7 @@ async def select_transport(
:param is_initiator: true if we are the initiator, false otherwise
:return: selected secure transport
"""
protocol: TProtocol
protocol: TProtocol | None
communicator = MultiselectCommunicator(conn)
if is_initiator:
# Select protocol if initiator
Expand All @@ -114,5 +117,7 @@ async def select_transport(
else:
# Select protocol if non-initiator
protocol, _ = await self.multiselect.negotiate(communicator)
if protocol is None:
raise MultiselectError("fail to negotiate a security protocol")
# Return transport from protocol
return self.transports[protocol]
7 changes: 6 additions & 1 deletion libp2p/stream_muxer/muxer_multistream.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from libp2p.peer.id import (
ID,
)
from libp2p.protocol_muxer.exceptions import (
MultiselectError,
)
from libp2p.protocol_muxer.multiselect import (
Multiselect,
)
Expand Down Expand Up @@ -73,14 +76,16 @@ async def select_transport(self, conn: IRawConnection) -> TMuxerClass:
:param conn: conn to choose a transport over
:return: selected muxer transport
"""
protocol: TProtocol
protocol: TProtocol | None
communicator = MultiselectCommunicator(conn)
if conn.is_initiator:
protocol = await self.multiselect_client.select_one_of(
tuple(self.transports.keys()), communicator
)
else:
protocol, _ = await self.multiselect.negotiate(communicator)
if protocol is None:
raise MultiselectError("fail to negotiate a stream muxer protocol")
return self.transports[protocol]

async def new_conn(self, conn: ISecureConn, peer_id: ID) -> IMuxedConn:
Expand Down
38 changes: 38 additions & 0 deletions tests/core/protocol_muxer/test_protocol_muxer.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,41 @@ async def dummy_handler(stream):
protocols = ms.get_protocols()

assert set(protocols) == {p1, p2, p3}


@pytest.mark.trio
async def test_negotiate_optional_tprotocol(security_protocol):
with pytest.raises(Exception):
await perform_simple_test(
None,
[None],
[None],
security_protocol,
)


@pytest.mark.trio
async def test_negotiate_optional_tprotocol_client_none_server_no_none(
security_protocol,
):
with pytest.raises(Exception):
await perform_simple_test(None, [None], [PROTOCOL_ECHO], security_protocol)


@pytest.mark.trio
async def test_negotiate_optional_tprotocol_client_none_in_list(security_protocol):
expected_selected_protocol = PROTOCOL_ECHO
await perform_simple_test(
expected_selected_protocol,
[None, PROTOCOL_ECHO],
[PROTOCOL_ECHO],
security_protocol,
)


@pytest.mark.trio
async def test_negotiate_optional_tprotocol_server_none_client_other(
security_protocol,
):
with pytest.raises(Exception):
await perform_simple_test(None, [PROTOCOL_ECHO], [None], security_protocol)