Skip to content

Commit 897389f

Browse files
authored
Merge pull request #1119 from pipermerriam/piper/issue-1092-fix-decoding-error-in-discovery
Handle malformed message during dao fork check
2 parents f01c656 + 32bb408 commit 897389f

File tree

3 files changed

+35
-11
lines changed

3 files changed

+35
-11
lines changed

p2p/exceptions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ class MalformedMessage(BaseP2PError):
4242
"""
4343
Raised when a p2p command is received with a malformed message
4444
"""
45+
pass
4546

4647

4748
class UnknownProtocolCommand(BaseP2PError):

p2p/peer.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -355,10 +355,18 @@ async def read_msg(self) -> Tuple[protocol.Command, protocol._DecodedMsgType]:
355355
# too much time is being spent on this again, we need to consider running this in a
356356
# ProcessPoolExecutor(). Need to make sure we don't use all CPUs in the machine for that,
357357
# though, otherwise asyncio's event loop can't run and we can't keep up with other peers.
358-
decoded_msg = cast(Dict[str, Any], cmd.decode(msg))
359-
self.logger.trace("Successfully decoded %s msg: %s", cmd, decoded_msg)
360-
self.received_msgs[cmd] += 1
361-
return cmd, decoded_msg
358+
try:
359+
decoded_msg = cast(Dict[str, Any], cmd.decode(msg))
360+
except MalformedMessage as err:
361+
self.logger.debug(
362+
"Malformed message from peer %s: CMD:%s Error: %r",
363+
self, type(cmd).__name__, err,
364+
)
365+
raise
366+
else:
367+
self.logger.trace("Successfully decoded %s msg: %s", cmd, decoded_msg)
368+
self.received_msgs[cmd] += 1
369+
return cmd, decoded_msg
362370

363371
def handle_p2p_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
364372
"""Handle the base protocol (P2P) messages."""
@@ -837,8 +845,8 @@ async def start_peer(self, peer: BasePeer) -> None:
837845
# check, we do it here because we want to perform it for incoming peer connections as
838846
# well.
839847
msgs = await self.ensure_same_side_on_dao_fork(peer)
840-
except DAOForkCheckFailure as e:
841-
self.logger.debug("DAO fork check with %s failed: %s", peer, e)
848+
except DAOForkCheckFailure as err:
849+
self.logger.debug("DAO fork check with %s failed: %s", peer, err)
842850
await peer.disconnect(DisconnectReason.useless_peer)
843851
return
844852
asyncio.ensure_future(peer.run(finished_callback=self._peer_finished))
@@ -962,9 +970,15 @@ async def ensure_same_side_on_dao_fork(
962970
else:
963971
msgs.append((cmd, msg))
964972
continue
965-
except (TimeoutError, PeerConnectionLost) as e:
973+
except (TimeoutError, PeerConnectionLost) as err:
974+
raise DAOForkCheckFailure(
975+
"Timed out waiting for DAO fork header from {}: {}".format(peer, err))
976+
except MalformedMessage as err:
966977
raise DAOForkCheckFailure(
967-
"Timed out waiting for DAO fork header from {}: {}".format(peer, e))
978+
"Malformed message while doing DAO fork check with {0}: {1}".format(
979+
peer, err,
980+
)
981+
) from err
968982

969983
try:
970984
request.validate_headers(headers)

p2p/protocol.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@
2222
from eth.constants import NULL_BYTE
2323
from eth.rlp.headers import BlockHeader
2424

25-
from p2p.exceptions import ValidationError
25+
from p2p.exceptions import (
26+
MalformedMessage,
27+
ValidationError,
28+
)
2629
from p2p.utils import get_devp2p_cmd_id
2730

2831

@@ -87,7 +90,13 @@ def decode_payload(self, rlp_data: bytes) -> _DecodedMsgType:
8790
else:
8891
decoder = sedes.List(
8992
[type_ for _, type_ in self.structure], strict=self.decode_strict)
90-
data = rlp.decode(rlp_data, sedes=decoder)
93+
try:
94+
data = rlp.decode(rlp_data, sedes=decoder)
95+
except rlp.DecodingError as err:
96+
raise MalformedMessage(
97+
"Malformed %s message: %r".format(type(self).__name__, err)
98+
) from err
99+
91100
if isinstance(self.structure, sedes.CountableList):
92101
return data
93102
return {
@@ -99,7 +108,7 @@ def decode_payload(self, rlp_data: bytes) -> _DecodedMsgType:
99108
def decode(self, data: bytes) -> _DecodedMsgType:
100109
packet_type = get_devp2p_cmd_id(data)
101110
if packet_type != self.cmd_id:
102-
raise ValueError("Wrong packet type: {}".format(packet_type))
111+
raise MalformedMessage("Wrong packet type: {}".format(packet_type))
103112
return self.decode_payload(data[1:])
104113

105114
def encode(self, data: _DecodedMsgType) -> Tuple[bytes, bytes]:

0 commit comments

Comments
 (0)